Compare commits

...

14 Commits

Author SHA1 Message Date
lukaszraczylo 546ceb949c security: remediate audit findings (ranks 1–16 + 22 Lows) + yaegi load validation (#144)
* fix(security): encrypt session cookies + fail closed on invalid config

Batch 1 of security audit remediation (ranks 1, 2, 6).

- session.go: derive independent HMAC + AES-256 keys via stdlib HKDF-SHA256
  and build the gorilla cookie store with both, so session cookies are now
  encrypted, not merely signed. The single-key store previously left OIDC
  access/refresh/ID tokens recoverable from raw cookie bytes. Cookie format
  changes, so existing sessions are invalidated on deploy (one-time re-login).
- main.go: call config.Validate() at construction and error out on failure,
  instead of silently substituting a public hardcoded encryption key for
  empty/short keys (which allowed session forgery). The yaegi analyzer
  passes via .traefik.yml testData.
- settings.go: isValidSecureURL permits plaintext HTTP for loopback hosts
  only (RFC 8252); remote providers must still use HTTPS.
- tests: complete configs that did not satisfy Validate(); add regression
  tests in security_audit_fixes_test.go.

Configs below documented minimums (rateLimit < 10, key < 32 chars) are now
rejected at startup (fail closed).

* fix(security): validate discovered OIDC endpoints + pin introspection host

Batch 2 of security audit remediation (ranks 3, 4).

- url_helpers.go: add validateDiscoveredEndpoint, an SSRF screen for endpoints
  taken from the provider discovery document (jwks_uri, token, authorization,
  revocation, end_session, introspection, registration). Blocks link-local
  (cloud metadata 169.254.169.254), multicast, unspecified and private
  addresses (unless allowPrivateIPAddresses); blocks loopback unless the
  configured providerURL is itself loopback (dev/test). Cross-domain JWKS
  hosts (e.g. Google) stay allowed. Add sameHost helper.
- main.go: updateMetadataEndpoints screens every discovered endpoint and
  blanks any that fail (fail closed downstream). The introspection endpoint
  carries the client secret via HTTP Basic, so it is additionally pinned to
  the providerURL host to stop a poisoned discovery document exfiltrating the
  secret to an attacker-controlled host.
- tests: regression tests for the SSRF guard and the host pin.

* fix(security): close open redirects + anchor excluded-URL matching

Batch 3 of security audit remediation (ranks 5, 14, 15).

- auth_flow.go: run the stored incoming path through normalizeLogoutPath
  before using it as the post-login redirect, so //evil.com and /\evil.com
  payloads become host-relative (open-redirect, rank 5).
- url_helpers.go: excluded-URL matching is anchored at a natural boundary
  (exact, sub-path "/", or file extension "."), so excluding "/public" no
  longer also bypasses auth on "/publicsecret"; "/favicon" still matches
  "/favicon.ico" (rank 14).
- internal/utils: X-Forwarded-Host is sanitized (first value only; reject
  CRLF/whitespace/multi-value) before building redirect URLs (rank 15).
- helpers.go: the logout redirect used when there is no provider end-session
  endpoint is host-relative, never an absolute URL derived from the
  client-controllable request host (logout open-redirect, rank 15).
- tests: update two logout cases that asserted the old absolute redirect;
  add regression tests.

* fix(security): reject unverified Azure tokens; fix transport TLS reuse

Batch 4 of security audit remediation (ranks 7, 11).

- token_validation_rs.go: an Azure nonce-bearing access token that cannot be
  cryptographically verified no longer returns "authenticated" when there is
  no ID token to corroborate it; it refreshes (if possible) or forces
  re-authentication instead of failing open (rank 7).
- http_client_pool.go: the at-limit transport-reuse path now takes the write
  lock before mutating refCount (fixes a data race) and only reuses a
  transport whose TLS settings (CA pool + InsecureSkipVerify) match the
  caller's, never one with a different trust store; if none matches it returns
  nil so the caller falls back to a verifying default transport (rank 11).
- tests: add a transport-pool TLS-isolation regression test.

* fix(security): stop logging templated header values (token leak)

Batch 5 of security audit remediation (rank 16).

middleware.go: templated downstream headers commonly carry the access token
(e.g. "Authorization: Bearer {{.AccessToken}}"). The debug log line printed
the full header value, leaking credentials into logs. Log the header name and
byte length instead.

* fix(security): cache-key collision, cache-config divergence, fleet cleanup

Batch 6 of security audit remediation (ranks 9, 10, 12).

- token_manager.go: detectTokenType keys its cache on a SHA-256 hash of the
  full token instead of the first 32 chars (which are only the base64url JWT
  header). Distinct tokens sharing alg+kid no longer collide and get
  mis-classified (rank 10).
- cache_manager.go: the process-global cache manager is initialized once and
  shared across plugin instances; it now logs a loud warning when a later
  instance requests a different explicit Redis backend that is silently
  ignored, surfacing the cross-instance state-isolation hazard (rank 9).
- singleton_resources.go / main.go / utilities.go: track a process-global live
  instance count; the shared singleton-token-cleanup task is stopped only when
  the LAST instance shuts down, so one instance's Close() (e.g. a config reload)
  no longer kills cleanup for surviving instances (rank 12).
- tests: update TestDetectTokenTypeCaching for the new key; add regression tests.

* fix(security): bound introspection cache + cookie lifetime to config

Batch 7 of security audit remediation (ranks 8, 13).

- token_introspection.go: when requireTokenIntrospection is enabled, cap the
  positive introspection-result cache at 30s (instead of 5m) so a token
  revoked at the provider stops passing within ~30s, matching the operator's
  near-real-time revocation expectation (rank 8).
- session.go: bind the cookie store's MaxAge to the configured sessionMaxAge,
  so the cookie codec's cryptographic timestamp validity is no longer fixed at
  gorilla's 30-day default; a stolen cookie is valid only for the configured
  session lifetime (rank 13).
- tests: add a cookie-lifetime regression test.

* fix(security): low-severity hardening (cache, DoS caps, PKCE, throttle)

Batch 8 of security audit remediation — low severity
(ranks 24, 25, 27, 29, 31, 36, 37, 41, 45, 46, 49).

- universal_cache.go: updateLocalCache updates an existing key in place instead
  of orphaning its LRU element and double-counting currentSize/currentMemory
  (rank 36 — the only production-reachable bug in this batch).
- jwk.go / metadata_cache.go / token_introspection.go: bound response bodies
  with io.LimitReader (1 MiB) to prevent memory exhaustion from a hostile or
  buggy provider (ranks 24, 25).
- jwk.go: skip JWKs not usable for signature verification (use != sig, or
  key_ops without "verify") when building the key set (rank 49).
- auth_flow.go: fail closed at the callback when PKCE is enabled but the code
  verifier is missing, instead of silently dropping it (rank 27).
- utilities.go / main.go: match allowedUserDomains case-insensitively (rank 31).
- bearer_auth.go: a single success no longer wipes an active per-IP penalty;
  the counter resets only when no penalty is in effect (rank 29).
- main.go: handle (not discard) the NewSessionManager error (rank 37).
- error_recovery.go: take a write lock in isServiceDegraded (it deletes from a
  map); compare retryable-error substrings case-insensitively (ranks 45, 46).
- singleton_resources.go: bind the generic-cache cleanup goroutine to the
  resource-manager shutdown channel so it cannot outlive its owner (rank 41).
- tests: update the bearer throttle test to the corrected penalty semantics.

* fix(security): header sanitization, issuer pinning, fail-closed paths

Batch 9 of security audit remediation (ranks 18, 19, 20, 21, 22, 30, 33, 34).

- middleware.go / bearer_auth.go: sanitize claim-derived values on the cookie
  auth path before injecting them into downstream headers. Drop group/role and
  identifier values containing control chars, bidi-override runes, or the
  , ; = delimiters (a comma would inject phantom entries into X-User-Groups);
  reject control/bidi/over-length in rendered templated header output (but
  permit , ; = in free-form values such as a bearer token). The bearer path
  already sanitized; the cookie path did not (ranks 33, 34).
- main.go / metadata_cache.go: pin the discovered issuer to the configured
  provider host (sameHost) and refuse/never-cache a mismatch, so a poisoned
  discovery document cannot redefine the JWT trust anchor (ranks 21, 22).
- token_introspection.go: when a distinct API audience is configured, fail
  closed on a missing or mismatched introspection audience; aud parsed as
  string-or-array per RFC 7662 (rank 19).
- logout.go: front-channel logout requires a matching issuer; an empty iss is
  rejected (blocks unauthenticated forced-logout via a known sid) (rank 30).
- token_validation_rs.go: an opaque access token with no ID token and no
  successful introspection fails closed (re-auth) instead of authenticating
  (ranks 18, 20).
- tests: realistic same-host provider mocks; regression tests for the header
  sanitization distinction and the fail-closed paths.

* chore(security): remove unwired dead code with latent footguns

Batch 10 of security audit remediation — delete confirmed-dead, unwired
subsystems (ranks 26, 35, 50). None had a production caller (grep-verified);
removal eliminates the latent footguns and ~2.1k lines of dead code.

- token_validator.go (deleted): an unused *TokenValidator whose validateJWT set
  Valid=true with NO signature verification — a severe footgun if ever wired
  (rank 50). The wired RS-aware validators are unaffected.
- security_monitoring.go (deleted): an unused *SecurityMonitor / ExtractClientIP
  that trusted spoofable X-Forwarded-For / X-Real-IP. The live bearer throttle
  uses clientIPForBearer (RemoteAddr-only), unchanged (rank 35).
- dynamic_client_registration.go: removed the RFC 7592 management methods
  (Update/Read/DeleteClientRegistration) that dereferenced an attacker-
  influenced RegistrationClientURI with the registration token attached and no
  HTTPS/SSRF gate, and had no callers. The wired RFC 7591 RegisterClient and
  credential-store helpers are kept (rank 26).
- tests: removed the tests covering the deleted code.

* chore: add Makefile with yaegi load validation

No Makefile existed. The new `yaegi-validate` target interprets the plugin
under the yaegi interpreter the same way Traefik loads it, catching yaegi-only
incompatibilities (unsupported stdlib symbols, reflection edge cases) that the
native `go build` / `go test` toolchain does not. Importing the plugin forces
yaegi to interpret every file plus its vendored deps; CreateConfig + New
exercise the instantiation path.

- cmd/yaegicheck/main.go: the load driver, marked //go:build ignore so it is
  excluded from `go build ./...` (avoids VCS-stamping a main binary, which
  fails in git-worktree layouts) yet is run explicitly by yaegi.
- Makefile: build / fmt / vet / lint / test / vendor / yaegi-validate / check
  targets; `make check` runs vet + tests + yaegi-validate.

Verified: `make yaegi-validate` passes on this branch — the HKDF cookie
encryption, net-based endpoint validation, and claim sanitizers all interpret
and instantiate cleanly under yaegi.

* ci: bump workflow Go toolchain to 1.25; pin yaegi-validate to v0.16.1

Traefik v3.7.1 (the deployed version) is built with `go 1.25.0`, so the PR and
release workflows now use Go 1.25.x to match the toolchain Traefik uses.

Important distinction: the CI Go version is the build TOOLCHAIN. The plugin's
actual interpreter-compatibility ceiling is the yaegi version Traefik bundles
(v0.16.1, which declares go 1.21 and ships a ~Go 1.22 stdlib symbol surface),
NOT the CI Go version. That ceiling is enforced by `make yaegi-validate` plus
the go.mod language directive — e.g. it is why HKDF is hand-rolled with
hmac+sha256 rather than Go 1.24's crypto/hkdf, which yaegi v0.16.1 lacks.

Also pin Makefile YAEGI_VERSION to v0.16.1 (what Traefik v3.7.1 vendors) so
yaegi-validate exercises the real deployed interpreter instead of @latest,
which could pass on a newer yaegi that supports symbols the deployed one does
not.

* docs: align README/CONFIGURATION with branch behavior changes

- excludedURLs: documented as segment/extension-boundary matching (was
  "prefix-matched") — "/public" no longer also matches "/publicsecret" (rank 14).
- Front-channel logout now requires a matching `iss`; requests without one are
  rejected with 400 (rank 30).
- Add an "Upgrading from an earlier release" note: session cookies are now
  AES-256 encrypted with lifetime tracking sessionMaxAge (one-time re-login on
  upgrade), and invalid configuration (rateLimit < 10, key < 32 bytes, missing
  callbackURL, non-HTTPS remote providerURL) now fails closed at startup.

* fix: remove staticcheck-flagged unused functions; wire staticcheck into make check

CI Static Analysis (standalone staticcheck) failed with U1000 "unused":
- dynamic_client_registration.go: deleteCredentialsFromStore — its only caller
  was the RFC 7592 DeleteClientRegistration removed in the dead-code batch.
- token_test.go: createTestJWTSimple — its only callers were the TokenValidator
  tests removed in the same batch.
Both confirmed to have zero remaining callers and removed. build / vet /
go test ./... / staticcheck ./... all green.

The pre-commit hook runs golangci-lint, but CI runs standalone staticcheck
(which flags U1000). Add a `staticcheck` Makefile target and include it in
`make check` so this class of finding is caught locally before push.

* fix(test): stabilize flaky TestWorkerPool_TaskPanic

tasksFailed is incremented in the worker's deferred recover(), which runs after the panicking task's own defer wg.Done(). wg.Wait() could therefore return before the failure was recorded, so reading the counter immediately raced and flaked on slow CI runners. Poll until the failure lands (2s budget) instead. Verified 200x plain + 50x under -race/GOMAXPROCS=1.
2026-05-30 14:10:32 +01:00
lukaszraczylo f75b2f20e0 fix: resolve cache eviction lock-up and migrate telemetry [patch-release]
universal_cache: stop the write-lock convoy / 100%-CPU spin (observed via pprof: one ServeHTTP goroutine holding c.mu.Lock for hours while 119 requests queued). The per-request populate path (updateLocalCache) PushFronted a duplicate LRU node + overwrote items[key] without removing the prior node; once eviction deleted the key, orphan nodes at Back() were never removable and the eviction loop spun forever under the write lock. Replace the entry in place (mirroring setLocal) and harden evictOldest with a forward-progress guard. Adds universal_cache_orphan_test.go.

telemetry: delete the hand-rolled client; call oss-telemetry v0.2.3 (vendored, Yaegi-safe) directly from New(), once per process via sync.Once.

version: add version.go + workflow-prepare.sh so the release semver is stamped into source at build time (the value cannot be resolved at runtime under Yaegi). dev/source builds keep the 0.0.0-dev sentinel and emit no telemetry.
2026-05-30 13:22:03 +01:00
paiking1 cf6ed1da55 feat: feat: add extraAuthParams (extra authorization request parameters) (#139)
Adds optional extraAuthParams map[string]string config.

Extra params are appended to the authorization request but can never
override plugin-managed params (client_id, state, nonce, etc.).
2026-05-27 21:41:09 +01:00
lukaszraczylo f821b8829b fix: remove write-lock convoy in getLocal + fix mutateState CAS bug
UniversalCache.getLocal(): when a cached token expires, the RLock fast
path (line 385-398) previously fell through to c.mu.Lock() (write lock).
Under Yaegi, the write-lock holder takes 10-100ms for LRU manipulation,
and Go's RWMutex writer-priority blocks ALL new RLock callers. A single
expired-token event turned every concurrent request from read-parallel
into write-serialized — the convoy that produced the 737-goroutine
pileup at 0x400275a608 (pprof captured at /tmp/traefik-spike-1779663149).

Fix: return (nil, false) immediately on expiry for Token/JWK/Session
cache types. The periodic cleanup goroutine handles eviction. Write lock
is never taken on the read path for these cache types.

refreshAttemptTracker.mutateState(): the CAS loop used
t.state.CompareAndSwap(t.state.Load(), next) — a second Load that can
see a different value from a concurrent writer, silently overwriting
their update. Fixed to CompareAndSwap(cur, next) using the snapshot we
computed the mutation from.
2026-05-25 00:06:47 +01:00
lukaszraczylo 5f9c574f95 refactor: delete dead non-RS validators; tests use RS variants
After v1.0.20 the non-RS validation chain had no production callers —
middleware.ServeHTTP dispatched exclusively through isUserAuthenticatedRS.
The orphaned functions stayed reachable only from a handful of test
files and risked silent logic drift against their RS counterparts.

Deleted from production code (~440 LOC):
  - auth_flow.go:        isUserAuthenticated
  - token_manager.go:    validateAzureTokens
  - token_manager.go:    validateGoogleTokens
  - token_manager.go:    validateStandardTokens
  - token_manager.go:    validateTokenExpiry
  - removed now-unused encoding/base64 and encoding/json imports
    from token_manager.go (only the deleted validateStandardTokens
    needed them; the RS variant in token_validation_rs.go keeps its
    own imports).

Added (3 LOC):
  - token_validation_rs.go: validateGoogleTokensRS (trivial delegator,
    parity with the deleted non-RS variant so isUserAuthenticatedRS
    can dispatch cleanly).

Tests ported (10 call sites across 3 files):
  - audience_test.go:                ts.tOidc.validateStandardTokens
  - azure_oidc_test.go:              tOidc.validateAzureTokens,
                                     ts.tOidc.validateGoogleTokens,
                                     ts.tOidc.validateAzureTokens,
                                     ts.tOidc.isUserAuthenticated
  - issue134_followup_graph_test.go: oidc.validateAzureTokens (4x)

Each ported site now constructs a *requestState from its existing
*SessionData via (&requestState{}).captureSession(session) and calls
the *RS variant. Same data, different read source.

Net diff: -440 LOC production, ~+25 LOC tests, +3 LOC stub.
Production now has a single source of truth for token validation;
no parallel implementations to keep in sync.

All tests pass with -race; golangci-lint clean.
2026-05-23 13:04:26 +01:00
lukaszraczylo 7c6f09fb20 feat(middleware): RS-aware token validators (kill ~21 RLocks/request)
Adds token_validation_rs.go with requestState-aware variants of the
token validation path:

  isUserAuthenticatedRS(rs) -> dispatches by provider
    validateStandardTokensRS(rs) -> standard path (eliminates 17 RLocks)
    validateAzureTokensRS(rs)    -> Azure path     (eliminates 10 RLocks)
    validateGoogleTokensRS(rs)   -> delegates to standard
  validateTokenExpiryRS(rs, tok) -> shared expiry check (eliminates 4 RLocks)

middleware.ServeHTTP now calls isUserAuthenticatedRS(rs) on the hot
path. The pre-v1.0.20 non-RS variants are kept untouched for tests
and any future caller that doesn't have a captured snapshot.

Why
---
The standard validation path read SessionData via session.GetX() 17
times, with GetRefreshToken alone called 11 times (every "return
'needs refresh'" branch re-reads it). Each call acquires
sd.sessionMutex.RLock(). Under Yaegi each RLock costs ~1-5ms of
interpreter dispatch. The captured snapshot already lives on rs, so
the RS variants substitute direct struct field reads.

Per-request cost on the hot authenticated path
----------------------------------------------
  ServeHTTP enters:
    + 1 RLock to populate rs (was 0)
  Validation path:
    Standard: was 17 RLocks, now 0
    Azure:    was 10 RLocks, now 0
  processAuthorizedRequestRS:
    was 4-6 GetX calls, now 0 (already in v1.0.19)

Net: ~22-27 fewer Yaegi-dispatched RLock acquisitions per authenticated
request on the hot path.

Caveats
-------
* Refresh / expired / callback paths still use the non-RS validators
  because they can mutate session state between validation and use.
* The RS variants are by-design line-for-line equivalents of the
  originals. If logic in the originals changes, the RS variants need
  matching updates. This is acceptable for now; a future refactor
  could collapse them once the non-RS callers are gone.

All tests pass with -race; golangci-lint clean.
2026-05-23 12:38:42 +01:00
lukaszraczylo 68e1c4319c feat(middleware): per-request context object (requestState)
Adds requeststate.go and threads a *requestState through the
ServeHTTP -> processAuthorizedRequestRS -> forwardAuthorized path.
rs is allocated once at the top of ServeHTTP, populates SessionData
field snapshots under a SINGLE sd.sessionMutex.RLock, and caches the
MetadataSnapshot. Downstream handlers read the cached fields instead
of calling session.GetX() / t.metadataSnap() repeatedly.

Why
---
Under Yaegi each method dispatch (including RWMutex.RLock) costs
~1-5ms of interpreter overhead. SessionData getters each take an
RLock on sd.sessionMutex; the previous hot path called 5-7 of them
per request (GetAuthenticated, GetAccessToken, GetIDToken,
GetRefreshToken, GetUserIdentifier, plus the same set again inside
processAuthorizedRequest). With one batched RLock + cached fields,
that drops to a single RLock for the whole handler chain.

This is scoped — not a wholesale architectural refactor:

* requestState is per-request (alloc at ServeHTTP entry, dropped on
  return). It is NOT a shared cache and never escapes the request.
* The original processAuthorizedRequest is kept unchanged for any
  callers we don't migrate this round (bearer path, callback
  handlers, expired-token handlers). New code path is the RS-aware
  processAuthorizedRequestRS, which middleware.ServeHTTP now uses for
  the happy authenticated-and-not-needing-refresh case.
* Cross-request caches (tokenCache, JWKCache, sessionEntries,
  sessionInvalidationCache) are unchanged. rs is additive, not a
  replacement.

What this does NOT change
-------------------------
* The refresh path still calls session.GetX() in middleware.go
  (handleExpiredToken, refreshToken, defaultInitiateAuthentication)
  because those flows can mutate session state and a stale rs would
  be wrong.
* validateStandardTokens still has its own session.GetX() calls.
  Deep plumbing into the token-verification path is a follow-up.
* No semantic changes to authentication, refresh, or session
  lifecycle — only the read path is optimised.

All tests pass with -race; golangci-lint clean.
2026-05-23 12:22:51 +01:00
lukaszraczylo 17e3f8ef62 fix: snapshot patterns for refresh-tracker and metadata URLs
Two related lock-free snapshot refactors addressing the remaining
post-v1.0.16 code-review findings.

1. refreshAttemptTracker: per-field atomic.Load/Store -> atomic.Value
   snapshot of *attemptState (refresh_coordinator.go).

   Previously each tracker held five independently-atomic fields. The
   cooldown-exit reset wrote cooldownEndNano = 0 first, then separately
   stored attempts = 1 and windowStartNano = now. A concurrent
   isInCooldown call could observe cooldownEndNano = 0 (reset just
   completed) with attempts still at MaxRefreshAttempts, immediately
   triggering a fresh cooldown — a benign double-trigger race that
   nonetheless meant the state machine had observable intermediate
   states.

   New design: state is a *attemptState (immutable) published via
   atomic.Value. All transitions (record/success/failure/window-reset/
   cooldown-enter/cooldown-exit) go through mutateState, which runs a
   CAS loop: load current snapshot -> construct fresh snapshot ->
   CompareAndSwap. Either the entire new state publishes or none of
   it does — no intermediate visibility, no cross-field race.

   Under Yaegi this collapses 3-5 per-field atomic dispatches into one
   atomic.Value.Load on the read path. Write paths pay an extra
   allocation for the new snapshot but avoid the cross-field hazard.

2. MetadataSnapshot: hot-path readers use atomic.Value instead of
   metadataMu.RLock (middleware.go, types.go, main.go, utilities.go).

   middleware.ServeHTTP previously took metadataMu.RLock on every
   non-bypass request to read the single field issuerURL. Under Yaegi
   each RLock acquisition costs 1-5ms of interpreter dispatch.
   updateMetadataEndpoints now also publishes an immutable
   *MetadataSnapshot via atomic.Value; the hot-path reader loads it
   in one op via t.metadataSnap(). Falls back to the legacy
   metadataMu.RLock pattern when the snapshot is unpublished (some
   test setups initialize the struct fields directly without going
   through updateMetadataEndpoints).

   Less-frequent callers (helpers, logout, token_introspection) still
   take metadataMu.RLock and are unchanged. The snapshot strictly
   subsets the metadataMu-protected fields, so those readers see
   identical data.

Note on atomic.Pointer[T]: this would have been the cleaner type but
yaegi v0.16.1's stdlib (used by traefik:v3.7.1) exposes only the
legacy unsafe.Pointer-based atomic primitives — no generic Pointer[T].
atomic.Value provides the same semantics via interface{} + type assert.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:31:51 +01:00
lukaszraczylo 827926bc3a fix(refresh-coordinator): trim per-request mutex/map ops
Three related changes addressing post-v1.0.15 code-review findings and
the user's observation that we have been "throwing maps around" — under
Yaegi every sync.Map / atomic / mutex dispatch costs ~1-5ms of
interpreter overhead, so the number of dispatches per request matters
as much as whether they are lock-free.

1. Remove cleanupTimers map + cleanupTimerMu sync.Mutex.

   scheduleDelayedCleanup previously tracked every pending timer in a
   map guarded by a mutex so a duplicate scheduling could cancel the
   prior timer. That "shouldn't happen" path was the only consumer of
   the map, but the mutex fired on every successful refresh
   completion — another per-request Yaegi-dispatched lock.
   performCleanup is already idempotent (LoadAndDelete on the sync.Map),
   so a duplicate firing is at worst a no-op second call. Dropped the
   map entirely; time.AfterFunc callback now calls performCleanup
   directly.

   Net: -1 sync.Mutex, -1 map field, -2 Lock/Unlock pairs per refresh
   completion. Shutdown simplified — no need to enumerate-and-stop
   timers since the callbacks no longer need teardown.

2. Reorder applyLeaderGates: cooldown check BEFORE recordRefreshAttempt.

   Previously incremented the attempt counter and then checked cooldown.
   Under burst load (many concurrent leaders with different token hashes
   but the same session) every goroutine could increment past
   MaxRefreshAttempts before any one of them observed the threshold,
   so the gate fired too late — same thundering-herd shape that drove
   v1.0.14 into the ground. Reordering makes the gate authoritative:
   only attempts that pass the gate are recorded.

   Semantic change: with MaxRefreshAttempts=N, exactly N attempts now
   run to completion before the (N+1)th is denied. Previously the Nth
   was denied as it tried to record (off-by-one stricter). Test
   assertion updated to N (was N-1).

3. Fix getOrCreateOperation MaxConcurrentRefreshes overshoot.

   The previous CAS-loop allowed a transient overshoot of up to N-1
   leaders when several goroutines all observed `current < max` in the
   same scheduling slice before any one of them succeeded their CAS —
   visible to readers as currentInFlightRefreshes > MaxConcurrentRefreshes
   for a brief window.

   Replaced with the ticket-and-return pattern: increment optimistically,
   decrement if we overshot. Strictly bounded: only the goroutine that
   produces max+1 sees max+1 as committed; the rest decrement back
   immediately. No CAS retry loop needed.

What was NOT done in this commit, and why:

* metadataMu.RLock cached via atomic snapshot — code-reviewer flagged
  this at severity 7 (3 RLocks per request: middleware.go:213,
  token_manager.go:349, token_manager.go:408). The clean fix is an
  atomic.Pointer[*MetadataSnapshot], but generic atomic.Pointer[T] is
  NOT exposed by yaegi v0.16.1's stdlib (only legacy unsafe.Pointer
  primitives). atomic.Value would work but requires a snapshot-struct
  refactor across ~15 call sites (helpers/logout/token_introspection/
  token_manager/main/middleware). Deferred to a focused future PR.

* isInCooldown multi-field reset race — the cooldown-reset CAS wins
  on cooldownEndNano, then separately stores attempts/consecutiveFailures/
  windowStartNano. A concurrent isInCooldown can briefly see the
  pre-reset attempts value and trigger a fresh cooldown. Semantic glitch
  (double-cooldown), not a correctness disaster. Fix is a single atomic
  pointer swap of an immutable snapshot — same atomic.Pointer constraint
  as above. Deferred.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:23:16 +01:00
lukaszraczylo abbfdb02a7 fix(jwk): replace JWKCache.mutex with singleflight pattern
JWKCache.GetJWKS previously held a sync.RWMutex.Lock() across the entire
HTTP round-trip to the IdP's JWKS endpoint (jwk.go:93). On a cold cache
(cold pod, JWK rotation, transient network blip) every concurrent
request piled up on this single global write-lock. Under Yaegi each
Lock() acquisition costs 10-50ms of interpreter dispatch — same
architectural shape as the bugs v1.0.14 and v1.0.15 already fixed,
just one that hadn't surfaced as the dominant bottleneck yet.

Code-review post-spike #2 flagged this at confidence 9/10 as the next
likely death-spiral on pod cold-start.

Change replaces the lock with a sync.Map-based singleflight: the first
caller for a given JWKS URL performs the fetch; concurrent callers
attach to the same *jwksFetch and wait on its done channel for the
result. Cold-cache cost is now O(1) HTTP fetch regardless of how many
goroutines are waiting, and no Yaegi-dispatched lock is held during the
fetch itself.

Correctness:
- LoadOrStore winner does the work; losers wait on a done channel.
- Done channel close is in a defer, so panics in fetchJWKS still
  unblock waiters.
- Map entry is removed in the same defer, so a fresh failed fetch can
  be retried by the next request without waiting for any stale entry.
- ctx.Done() unblocks waiters independently of the leader's progress.
- Re-checks the cache after winning LoadOrStore, since another fetch
  may have populated the cache between the initial miss and the win.

Cleanup: also removes a stray yaegi-extract output file
(github_com-lukaszraczylo-traefikoidc.go) that was accidentally
committed during local yaegi compatibility testing.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:05:24 +01:00
lukaszraczylo 72e2b682bb fix: eliminate per-request global mutexes in Yaegi hot paths
The v1.0.14 fix replaced one contended sync.RWMutex (RefreshCoordinator.
refreshMutex) with sync.Map. Production showed the same death-spiral
signature recurring ~2 hours later — same shape, different mutex:
65 goroutines stuck on a sync.(*RWMutex).Lock at one address, pod
pinned at 1000m CPU, identical Yaegi runCfg/reflect.Value.Call stack
pattern. The mutex was RefreshCoordinator.attemptsMutex.

Generalising: under Yaegi (interpreted Go for traefik plugins), any
per-request global mutex acquisition is a latent serialization point.
reflect.Value.Call dispatch on a held lock turns a microsecond
critical section into a multi-millisecond one, and on a GOMAXPROCS=1
pod the queue is unbounded.

This commit removes every per-request global mutex on the hot path:

1. RefreshCoordinator.attemptsMutex (sync.RWMutex)
   sessionRefreshAttempts: map -> sync.Map.
   refreshAttemptTracker: all fields atomic (int32, int64 UnixNano,
   cooldownEndNano == 0 as the not-in-cooldown sentinel, replacing
   the inCooldown bool).
   isInCooldown / recordRefreshAttempt / recordRefreshSuccess /
   recordRefreshFailure all become lock-free. Cooldown entry uses
   CompareAndSwapInt64 so only one goroutine logs the transition.

2. RefreshCircuitBreaker.mutex (sync.RWMutex)
   lastFailureTime / lastSuccessTime -> atomic.Int64 UnixNano.
   state and failures already atomic.
   AllowRequest / RecordSuccess / RecordFailure now pure atomic ops.

3. TraefikOidc.firstRequestMutex (sync.Mutex)
   firstRequestReceived bool -> firstRequestStarted int32.
   metadataRefreshStarted bool -> metadataRefreshStartedAtomic int32.
   ServeHTTP bootstrap path uses CompareAndSwapInt32 — fires once,
   zero steady-state cost. Previously the mutex was acquired on
   every non-health request forever.

4. TraefikOidc.metadataRetryMutex (sync.Mutex)
   lastMetadataRetryTime time.Time -> lastMetadataRetryNano int64.
   The 30-second retry throttle is now a CAS on lastMetadataRetryNano.

cleanupStaleEntries iterates via sync.Map.Range; eviction is a
CompareAndDelete by pointer identity so a tracker freshly re-used by
a concurrent caller is not lost.

Empirical evidence (3 specialist-agent analysis of the v1.0.14 spike,
profiles in /tmp/traefik-spike-1779511683/):
  * mutex profile: 97% delay in sync.(*Mutex).Unlock via
    HTTPHandlerSwitcher -> accesslog -> metrics -> backoff.RetryNotify
  * 65 stuck goroutines at one RWMutex address (0x40022eb648),
    identical Yaegi CFG pointer, all on rc.attemptsMutex via
    recordRefreshAttempt + isInCooldown
  * traffic driver: long-lived in-cluster Go-http-client doing
    ~5.4 req/s POST embeddings via OIDC cookie session → same
    sessionID → contention all funnels to one tracker entry

Yaegi support for sync/atomic confirmed at
github.com/traefik/yaegi@v0.16.1/stdlib/go1_22_sync_atomic.go:
AddInt32/Int64, LoadInt32/Int64, StoreInt32/Int64,
CompareAndSwapInt32/Int64 all exposed via reflect.ValueOf. Yaegi
dispatches each call through reflect.Value.Call to the COMPILED
atomic.* function, which executes a single hardware CAS/LOCK-XADD
instruction. Each atomic op still pays Yaegi dispatch cost but
cannot block — no queueing, no death spiral.

Trade-off acknowledged: v1.0.15 issues ~6-8 atomic/sync.Map ops per
leader-path request vs the 4 mutex ops of v1.0.14. Under low
contention this is a modest CPU bump. Under high contention it's
an unbounded → bounded transformation. Net win.

All tests pass with -race; golangci-lint clean.
2026-05-23 10:47:21 +01:00
lukaszraczylo ae4ccaa89d fix(refresh-coordinator): replace global RWMutex with sync.Map
Under Yaegi, the RefreshCoordinator.refreshMutex was held for tens of
milliseconds per request because every operation inside the critical
section (map access, isInCooldown, recordRefreshAttempt,
isUnderMemoryPressure, atomic ops, struct allocation) is dispatched
through reflect.Value.Call with full arg boxing/unboxing.

Concurrent refreshes on the same coordinator serialized into a queue
that grew without bound. Live capture in production (3 Grafana
dashboards left open) showed:
  * 63 goroutines stuck on rc.refreshMutex.Lock() for 1-11 minutes
  * pod pinned at 1000m CPU (GOMAXPROCS=1)
  * 5.15M allocs/sec, 0.45 RPS effective throughput
  * yaegi.call.func9 accounting for 92.66% of cumulative allocs
  * mutex profile dominated by sync.(*Mutex).Unlock via the request chain

Change inFlightRefreshes from map[string]*refreshOperation+RWMutex to
sync.Map and rewrite getOrCreateOperation to:
  1. Speculatively allocate the candidate operation.
  2. Atomically LoadOrStore by tokenHash. Joiners take the existing
     operation; leader takes the new one. No global lock acquired.
  3. Leader runs rate-limit / cooldown / memory-pressure gates AFTER
     the atomic store. Joiners share the leader's outcome via op.done.
  4. Reserve the concurrent-refresh slot via CompareAndSwap so the
     count cannot overshoot in absence of the old serializing lock.
  5. On any gate failure the leader calls failCandidate, which deletes
     the entry from sync.Map, records the error on op.result and closes
     op.done so any joiner that snuck in returns the same error.

performCleanup becomes a single sync.Map.LoadAndDelete, eliminating
the lock entirely on the cleanup path.

Net effect: critical section is no longer Yaegi-interpreted; it
collapses to atomic instructions on a sharded sync.Map. Refresh
contention disappears even under Yaegi.

All tests pass with -race; golangci-lint clean.
2026-05-23 02:34:49 +01:00
lukaszraczylo 984fd1c08f docs: add Telemetry section linking to oss-telemetry opt-out docs
Discloses the single anonymous adoption ping sent on first plugin
instantiation. Points users to the upstream README section for the
disclosure pattern and to the local telemetry.go for the inline
implementation.
2026-05-21 04:07:19 +01:00
lukaszraczylo 99bdd23986 feat: anonymous usage telemetry via inline oss-telemetry
Adds a yaegi-safe inline telemetry helper that fires a single
fire-and-forget ping at plugin load. Helps track adoption and version
spread. No persistent identifiers are collected.

Implementation notes:
- inline (no external dep) so Traefik plugin loader does not need to
  resolve a new vendored module
- stdlib-only, no generics, no range-over-int — verified to load under
  yaegi 0.16.x (full plugin import + CreateConfig/New symbol lookup OK)
- avoids `switch{case A,B,C:}` blocks where some yaegi releases
  mis-evaluate comma-separated case lists
- sync.Once guards against amplified pings on Traefik dynamic config
  reloads (which re-instantiate the middleware)

Opt out via any of:
  DO_NOT_TRACK=1
  OSS_TELEMETRY_DISABLED=1
  TRAEFIKOIDC_DISABLE_TELEMETRY=1
2026-05-21 03:20:36 +01:00
65 changed files with 3392 additions and 3283 deletions
+1 -1
View File
@@ -18,6 +18,6 @@ jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24.11"
go-version: "1.25.x"
coverage-threshold: 70
secrets: inherit
+1 -1
View File
@@ -19,5 +19,5 @@ jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24.11"
go-version: "1.25.x"
secrets: inherit
+61
View File
@@ -0,0 +1,61 @@
# traefikoidc — Makefile
# Run `make help` for available targets.
GO ?= go
GOPATH := $(shell $(GO) env GOPATH)
# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate
# tests the real interpreter, not a newer one that may support more. Traefik
# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is.
YAEGI_VERSION ?= v0.16.1
TEST_TIMEOUT ?= 480s
.DEFAULT_GOAL := help
.PHONY: help
help: ## Show this help
@grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}'
.PHONY: build
build: ## Compile all packages (native toolchain)
$(GO) build ./...
.PHONY: fmt
fmt: ## Format sources with gofmt
gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/')
.PHONY: vet
vet: ## Run go vet
$(GO) vet ./...
.PHONY: lint
lint: ## Run golangci-lint if available
@command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping"
.PHONY: staticcheck
staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.)
@command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; }
@GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./...
.PHONY: test
test: ## Run the test suite
$(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT)
.PHONY: vendor
vendor: ## Refresh and vendor dependencies
$(GO) mod tidy && $(GO) mod vendor
# yaegi-validate interprets the plugin under the yaegi interpreter the same way
# Traefik loads it. Native `go build`/`go test` use the standard compiler and do
# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection
# edge cases). This target does. Importing the package forces yaegi to interpret
# every file in it plus its vendored deps; CreateConfig + New exercise the
# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you
# need exact parity.
.PHONY: yaegi-validate
yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter
@command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); }
@echo ">> interpreting plugin under yaegi (as Traefik does)"
@DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go
.PHONY: check
check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation
+29 -1
View File
@@ -111,7 +111,8 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
| `logoutURL` | `callbackURL + "/logout"` | RP-initiated logout path. |
| `postLogoutRedirectURI` | `/` | Where to send users after logout. |
| `scopes` | appended to `openid profile email` | Extra OAuth scopes. Set `overrideScopes: true` to replace defaults. |
| `excludedURLs` | none | Prefix-matched paths that bypass auth. |
| `extraAuthParams` | none | Map of extra query parameters appended to the authorization request (e.g. `screen_hint: signup`, `login_hint`, `ui_locales`, `prompt`). Plugin-managed params (`client_id`, `state`, `nonce`, `redirect_uri`, `code_challenge`, `scope`, `response_type`, …) cannot be overridden. |
| `excludedURLs` | none | Paths that bypass auth, matched at a path-segment or file-extension boundary (e.g. `/public` matches `/public`, `/public/sub` and `/public.json`, but **not** `/publicsecret`). |
| `allowedUserDomains` | none | Restrict to email domains. |
| `allowedUsers` | none | Restrict to specific addresses (or claim values when `userIdentifierClaim != email`). |
| `allowedRolesAndGroups` | none | Require any of these roles/groups from ID-token claims. |
@@ -146,6 +147,18 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
## Production gotchas
### Upgrading from an earlier release
- **Sessions are re-issued once.** Session cookies are now AES-256 encrypted
(previously signed only) and their cryptographic lifetime tracks
`sessionMaxAge` (previously a fixed 30 days). Existing cookies become invalid
on upgrade, so users re-authenticate one time.
- **Invalid configuration now fails closed at startup** instead of being
silently accepted: a `sessionEncryptionKey` shorter than 32 bytes, a
`rateLimit` below 10, a missing `callbackURL`, or a non-HTTPS remote
`providerURL` are rejected. Plaintext HTTP is permitted only for loopback
hosts (local development).
### TLS termination at a load balancer
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
@@ -165,6 +178,8 @@ detected" when the same token hits different replicas. Two options:
For IdP-initiated logout (back/front-channel) in multi-replica setups, Redis is
**required** so a logout on one instance invalidates sessions on the others.
Front-channel logout requests must include a matching `iss` query parameter;
requests that omit it are rejected with `400`.
### Multiple middleware instances on the same host
@@ -411,6 +426,19 @@ namespaced claims, Cognito regions, GitLab self-hosted) live in
Set `logLevel: debug` to surface detail.
## Telemetry
On first plugin instantiation this middleware sends a single anonymous
adoption ping — project name, version, timestamp; no identifiers, no
request data, no token contents. Fire-and-forget with a 2-second timeout;
cannot block plugin load or panic.
Local source: [`telemetry.go`](./telemetry.go). Disclosure mirrors
**[oss-telemetry — Disabling telemetry](https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry)**.
Quick opt-out: set any of `DO_NOT_TRACK=1`, `OSS_TELEMETRY_DISABLED=1`,
or `TRAEFIKOIDC_DISABLE_TELEMETRY=1`.
## License
See [LICENSE](LICENSE).
+4 -2
View File
@@ -484,7 +484,8 @@ func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
session.SetAccessToken(opaqueAccessToken)
session.SetIDToken(idToken)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
if !authenticated || needsRefresh || expired {
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
authenticated, needsRefresh, expired)
@@ -623,7 +624,8 @@ func TestAuth0Scenario2StrictMode(t *testing.T) {
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
// In strict mode, this should FAIL (no fallback to ID token)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
if authenticated {
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
}
+9 -23
View File
@@ -182,6 +182,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
}
codeVerifier := session.GetCodeVerifier()
if t.enablePKCE && codeVerifier == "" {
t.logger.Error("PKCE is enabled but code verifier is missing from session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest)
return
}
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
@@ -263,7 +268,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
// Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored
// from the original request target before using it as the post-login
// redirect target. normalizeLogoutPath forces a host-relative path.
redirectPath = normalizeLogoutPath(incomingPath)
}
session.SetIncomingPath("")
@@ -305,28 +313,6 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// isUserAuthenticated determines the authentication status and refresh requirements.
// It delegates to provider-specific validation methods that handle different token types
// and expiration behaviors.
// Parameters:
// - session: The session data containing authentication tokens.
//
// Returns:
// - authenticated (bool): True if the user has valid tokens.
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
// - expired (bool): True if the session is unauthenticated, the token is missing,
// or the token verification failed for reasons other than nearing/actual expiration.
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if t.isAzureProvider() {
return t.validateAzureTokens(session)
} else if t.isGoogleProvider() {
return t.validateGoogleTokens(session)
}
// Auth0 and other providers can now use standard validation
// which handles opaque tokens generically
return t.validateStandardTokens(session)
}
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
xhr := req.Header.Get("X-Requested-With")
+8 -4
View File
@@ -262,7 +262,8 @@ func TestAzureOIDCRegression(t *testing.T) {
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
// Test that CSRF is preserved during Azure validation failures
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := tOidc.validateAzureTokensRS(rs)
// Should not be authenticated due to validation failure
if authenticated {
@@ -453,7 +454,8 @@ func TestValidateGoogleTokens(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
auth, refresh, expired := ts.tOidc.validateGoogleTokens(session)
rs := (&requestState{}).captureSession(session)
auth, refresh, expired := ts.tOidc.validateGoogleTokensRS(rs)
if auth != tt.expectedAuth {
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
@@ -637,7 +639,8 @@ func TestIsUserAuthenticated(t *testing.T) {
defer func() { ts.tOidc.issuerURL = originalIssuer }()
session := tt.setupSession()
auth, refresh, expired := ts.tOidc.isUserAuthenticated(session)
rs := (&requestState{}).captureSession(session)
auth, refresh, expired := ts.tOidc.isUserAuthenticatedRS(rs)
if auth != tt.expectedAuth {
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
@@ -762,7 +765,8 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
auth, refresh, expired := ts.tOidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
auth, refresh, expired := ts.tOidc.validateAzureTokensRS(rs)
if auth != tt.expectedAuth {
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
+101 -10
View File
@@ -149,6 +149,94 @@ func parseBearerJOSEHeader(token string) *bearerError {
return nil
}
// headerClaimRuneReason reports why a rune is unsafe to inject into a request
// header value, or "" if the rune is acceptable. Shared core of the bearer-path
// identifier sanitizer and the cookie-path header claim sanitizer: rejects
// control chars (CRLF/header injection), Unicode bidi-override runes (RTL
// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group
// name would inject extra entries into a comma-joined header).
func headerClaimRuneReason(r rune) string {
if reason := headerInjectionRuneReason(r); reason != "" {
return reason
}
// The , ; = delimiters are only unsafe for values placed into delimited or
// list contexts (a comma-joined header, or an identifier downstreams may
// split). They are valid in arbitrary single header values, so this stricter
// check is used for the cookie-path identifier and the group/role list, NOT
// for free-form templated header output (see headerValueReason).
if r == ',' || r == ';' || r == '=' {
return "delimiter character"
}
return ""
}
// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header
// value, or "" if acceptable. Rejects control characters (CR/LF header
// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs).
// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in
// free-form header values (e.g. an opaque "Authorization: Bearer <token>").
func headerInjectionRuneReason(r rune) string {
if unicode.IsControl(r) {
return "control character"
}
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
return "bidi-override character"
}
return ""
}
// headerValueReason reports why value is unsafe to forward as a free-form HTTP
// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0
// disables the check) and values containing control or bidi-override runes, but
// permits , ; = (valid in header values). Empty is allowed. The reason string
// never includes the value, so it is safe to log.
func headerValueReason(value string, maxLen int) string {
if maxLen > 0 && len(value) > maxLen {
return "exceeds max length"
}
for _, r := range value {
if reason := headerInjectionRuneReason(r); reason != "" {
return reason
}
}
return ""
}
// headerClaimValueReason reports why value is unsafe to inject into a
// downstream request header, or "" if it is acceptable. It rejects empty
// values, values exceeding maxLen (maxLen<=0 disables the length check), and
// values containing any rune rejected by headerClaimRuneReason. The reason
// string is safe to log (it never includes the value itself).
func headerClaimValueReason(value string, maxLen int) string {
if value == "" {
return "empty value"
}
if maxLen > 0 && len(value) > maxLen {
return "exceeds max length"
}
for _, r := range value {
if reason := headerClaimRuneReason(r); reason != "" {
return reason
}
}
return ""
}
// sanitizeHeaderClaimValue validates a claim-derived value before it is
// injected into a downstream request header. It trims surrounding whitespace
// and fails closed (ok=false) on empty values, values exceeding maxLen
// (maxLen<=0 disables the length check), or values containing any rune rejected
// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the
// bearer path — does not otherwise sanitize the principal identifier or the
// group/role strings joined into X-User-Groups / X-User-Roles.
func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) {
value := strings.TrimSpace(raw)
if headerClaimValueReason(value, maxLen) != "" {
return "", false
}
return value, true
}
// sanitizeBearerIdentifier validates and trims a principal identifier before
// it is injected into request headers. Layered defense: net/http will reject
// CRLF on the wire too, but rejecting early gives clearer error logs and
@@ -163,15 +251,8 @@ func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
}
for _, r := range identifier {
if unicode.IsControl(r) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains control character")
}
// Unicode bidi-override range (RTL spoofing of admin UI / SIEM).
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains bidi-override character")
}
if r == ',' || r == ';' || r == '=' {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains delimiter character")
if reason := headerClaimRuneReason(r); reason != "" {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason)
}
}
return identifier, nil
@@ -342,7 +423,17 @@ func (b *bearerFailureTracker) recordSuccess(ip string) {
}
b.mu.Lock()
defer b.mu.Unlock()
delete(b.entries, ip)
e, ok := b.entries[ip]
if !ok {
return
}
// Preserve an active penalty so a single success cannot wipe an in-effect
// lockout; only reset the counter when no penalty is active or it has expired.
now := time.Now()
if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) {
e.count = 0
e.firstFailureAt = now
}
}
// clientIPForBearer returns the source IP used to key the failure tracker.
+46 -28
View File
@@ -67,31 +67,31 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
t.Helper()
sm := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("error"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://issuer.example.com",
audience: "https://api.example.com",
clientID: "https://api.example.com",
tokenCache: NewTokenCache(),
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
allowedRolesAndGroups: map[string]struct{}{},
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
ctx: context.Background(),
enableBearerAuth: true,
stripAuthorizationHeader: true,
bearerEmitWWWAuthenticate: true,
bearerOverridesCookie: false,
bearerIdentifierClaim: "sub",
maxIdentifierLength: 256,
maxTokenAge: 24 * time.Hour,
bearerFailureThreshold: 20,
bearerFailureWindow: 60 * time.Second,
bearerFailurePenalty: 60 * time.Second,
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
next: next,
logger: NewLogger("error"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://issuer.example.com",
audience: "https://api.example.com",
clientID: "https://api.example.com",
tokenCache: NewTokenCache(),
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
allowedRolesAndGroups: map[string]struct{}{},
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
ctx: context.Background(),
enableBearerAuth: true,
stripAuthorizationHeader: true,
bearerEmitWWWAuthenticate: true,
bearerOverridesCookie: false,
bearerIdentifierClaim: "sub",
maxIdentifierLength: 256,
maxTokenAge: 24 * time.Hour,
bearerFailureThreshold: 20,
bearerFailureWindow: 60 * time.Second,
bearerFailurePenalty: 60 * time.Second,
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
}
oidc.extractClaimsFunc = extractClaims
close(oidc.initComplete)
@@ -303,15 +303,33 @@ func TestBearerFailureTracker(t *testing.T) {
if b, retry := tr.blocked(ip); !b || retry <= 0 {
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
}
// Success clears the counter.
// A success while a penalty is active must NOT wipe the in-effect lockout
// (otherwise a single success could clear an attacker's penalty).
tr.recordSuccess(ip)
if b, _ := tr.blocked(ip); b {
t.Fatalf("expected unblocked after success")
if b, _ := tr.blocked(ip); !b {
t.Fatalf("expected still blocked after success while penalty active")
}
// Other IPs are unaffected.
if b, _ := tr.blocked("10.0.0.2"); b {
t.Fatalf("unrelated IP should not be blocked")
}
// With an expired penalty, a success resets the counter so a subsequent
// sub-threshold failure does not immediately re-block.
tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond)
const ip2 = "10.0.0.3"
for i := 0; i < 3; i++ {
tr2.recordFailure(ip2)
}
time.Sleep(5 * time.Millisecond) // let the short penalty expire
if b, _ := tr2.blocked(ip2); b {
t.Fatalf("expected unblocked after penalty expiry")
}
tr2.recordSuccess(ip2) // resets count since penalty has passed
tr2.recordFailure(ip2) // single failure, well below threshold
if b, _ := tr2.blocked(ip2); b {
t.Fatalf("expected unblocked: counter should have reset after success")
}
}
// =============================================================================
+23 -2
View File
@@ -16,8 +16,9 @@ type CacheManager struct {
}
var (
globalCacheManagerInstance *CacheManager
cacheManagerInitOnce sync.Once
globalCacheManagerInstance *CacheManager
cacheManagerInitOnce sync.Once
cacheManagerActiveFingerprint string
)
// GetGlobalCacheManager returns a singleton CacheManager instance.
@@ -29,7 +30,9 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
fp := redisFingerprint(config)
cacheManagerInitOnce.Do(func() {
cacheManagerActiveFingerprint = fp
var redisConfig *RedisConfig
var logger *Logger
@@ -55,9 +58,27 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
// Warn loudly if a later instance asks for a DIFFERENT explicit Redis
// backend than the one that won initialization: the cache manager is a
// process-global singleton shared across plugin instances (yaegi), so this
// instance's divergent configuration is silently ignored, which would
// otherwise collapse cache/state isolation between routes (rank 9).
if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint {
NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp)
}
return globalCacheManagerInstance
}
// redisFingerprint returns a stable identifier for an explicitly-enabled Redis
// backend (address + key prefix), or "" when Redis is not explicitly enabled.
// Used to detect divergent cache configurations across plugin instances.
func redisFingerprint(config *Config) string {
if config == nil || config.Redis == nil || !config.Redis.Enabled {
return ""
}
return config.Redis.Address + "|" + config.Redis.KeyPrefix
}
// GetSharedTokenBlacklist returns the shared token blacklist cache
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
+46
View File
@@ -0,0 +1,46 @@
//go:build ignore
// Command yaegicheck verifies that the traefikoidc plugin can be imported and
// instantiated by the yaegi interpreter — the same way Traefik loads a plugin.
//
// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi
// to interpret every source file in the package (and its vendored
// dependencies), so any construct yaegi cannot handle (unsupported stdlib
// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load
// time. CreateConfig + New additionally exercise the instantiation path
// (session manager, cookie codec, caches, key derivation) under the interpreter.
package main
import (
"context"
"fmt"
"net/http"
"os"
oidc "github.com/lukaszraczylo/traefikoidc"
)
func main() {
cfg := oidc.CreateConfig()
cfg.ProviderURL = "https://accounts.google.com"
cfg.ClientID = "yaegi-check-client"
cfg.ClientSecret = "yaegi-check-secret"
cfg.CallbackURL = "/oauth2/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef"
cfg.RateLimit = 100
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h, err := oidc.New(context.Background(), next, cfg, "yaegi-check")
if err != nil {
fmt.Println("FAIL: New returned an error under yaegi:", err)
os.Exit(1)
}
if h == nil {
fmt.Println("FAIL: New returned a nil handler under yaegi")
os.Exit(1)
}
if closer, ok := h.(interface{ Close() error }); ok {
_ = closer.Close()
}
fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi")
}
-76
View File
@@ -278,82 +278,6 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) {
}
}
// =============================================================================
// SECURITY MONITORING TESTS
// =============================================================================
func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) {
logger := NewLogger("info")
config := SecurityMonitorConfig{
MaxFailuresPerIP: 5,
FailureWindowMinutes: 15,
BlockDurationMinutes: 30,
RapidFailureThreshold: 3,
CleanupIntervalMinutes: 60,
RetentionHours: 24,
EnablePatternDetection: true,
EnableDetailedLogging: false,
LogSuspiciousOnly: false,
}
sm := NewSecurityMonitor(config, logger)
if sm == nil {
t.Fatal("Expected non-nil SecurityMonitor")
}
// Start cleanup routine first (lowercase method)
sm.startCleanupRoutine()
// Give it a moment to start
time.Sleep(50 * time.Millisecond)
// Stop cleanup routine (public method)
sm.StopCleanupRoutine()
// Stop again should be safe
sm.StopCleanupRoutine()
}
func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) {
logger := NewLogger("info")
config := SecurityMonitorConfig{
MaxFailuresPerIP: 5,
FailureWindowMinutes: 15,
BlockDurationMinutes: 30,
RapidFailureThreshold: 3,
CleanupIntervalMinutes: 60,
RetentionHours: 24,
}
sm := NewSecurityMonitor(config, logger)
// Create handler
handler := &LoggingSecurityEventHandler{logger: logger}
// Register handler using AddEventHandler
sm.AddEventHandler(handler)
// Record a failure to trigger events
sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil)
}
func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) {
logger := NewLogger("debug")
handler := &LoggingSecurityEventHandler{logger: logger}
// Severity is a string in this implementation
events := []SecurityEvent{
{Type: "test", Severity: "low", Message: "low severity"},
{Type: "test", Severity: "medium", Message: "medium severity"},
{Type: "test", Severity: "high", Message: "high severity"},
{Type: "test", Severity: "critical", Message: "critical severity"},
}
for _, event := range events {
handler.HandleSecurityEvent(event)
}
}
// =============================================================================
// SESSION MANAGER TESTS
// =============================================================================
+1 -1
View File
@@ -178,7 +178,7 @@ clientSecret: your-client-secret
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
-199
View File
@@ -370,21 +370,6 @@ func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, res
return r.saveCredentials(resp)
}
// deleteCredentialsFromStore removes credentials from the configured storage backend
// Falls back to legacy file-based deletion if no store is configured
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
// Use store if available
if r.store != nil {
return r.store.Delete(ctx, r.providerURL)
}
// Fallback to legacy file-based deletion
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// saveCredentials persists client credentials to a file (legacy method)
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
@@ -423,187 +408,3 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse,
return &resp, nil
}
// UpdateClientRegistration updates an existing client registration using RFC 7592
// This requires the registration_client_uri and registration_access_token from the original registration
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to update")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Build update request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build update request: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create update request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("update request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read update response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse update response: %w", err)
}
// Update cache
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
return &regResp, nil
}
// ReadClientRegistration reads the current client registration using RFC 7592
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to read")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create read request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("read request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse read response: %w", err)
}
return &regResp, nil
}
// DeleteClientRegistration deletes the client registration using RFC 7592
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return fmt.Errorf("no existing registration to delete")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
if err != nil {
return fmt.Errorf("failed to create delete request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return fmt.Errorf("delete request failed: %w", err)
}
defer resp.Body.Close()
// Handle error responses (204 No Content is success)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
}
// Clear cache
r.mu.Lock()
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials from storage if persistence is enabled
if r.config.PersistCredentials {
if err := r.deleteCredentialsFromStore(ctx); err != nil {
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
}
}
r.logger.Info("Successfully deleted client registration")
return nil
}
-252
View File
@@ -735,258 +735,6 @@ func TestDCRConfigDefaults(t *testing.T) {
}
}
// TestUpdateClientRegistration tests the RFC 7592 client update functionality
func TestUpdateClientRegistration(t *testing.T) {
updateCalled := false
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPut {
updateCalled = true
// Verify authorization header
if r.Header.Get("Authorization") == "" {
t.Error("Missing Authorization header for update")
}
resp := ClientRegistrationResponse{
ClientID: "updated-client-id",
ClientSecret: "updated-client-secret",
RegistrationAccessToken: "new-access-token",
RegistrationClientURI: r.URL.String(),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
dcrConfig := &DynamicClientRegistrationConfig{
Enabled: true,
ClientMetadata: &ClientRegistrationMetadata{
RedirectURIs: []string{"https://example.com/callback"},
},
}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Perform update
ctx := context.Background()
resp, err := registrar.UpdateClientRegistration(ctx)
if err != nil {
t.Fatalf("Update failed: %v", err)
}
if !updateCalled {
t.Error("Update endpoint was not called")
}
if resp.ClientID != "updated-client-id" {
t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID)
}
}
// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality
func TestDeleteClientRegistration(t *testing.T) {
deleteCalled := false
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodDelete {
deleteCalled = true
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
tempDir := t.TempDir()
credentialsFile := filepath.Join(tempDir, "credentials.json")
// Create a credentials file to test deletion
os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600)
dcrConfig := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
CredentialsFile: credentialsFile,
}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "test-client-id",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Perform delete
ctx := context.Background()
err := registrar.DeleteClientRegistration(ctx)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if !deleteCalled {
t.Error("Delete endpoint was not called")
}
// Verify cache is cleared
if registrar.GetCachedResponse() != nil {
t.Error("Cached response should be cleared after deletion")
}
// Verify credentials file is deleted
if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) {
t.Error("Credentials file should be deleted")
}
}
// TestReadClientRegistration tests the RFC 7592 client read functionality
func TestReadClientRegistration(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
resp := ClientRegistrationResponse{
ClientID: "read-client-id",
ClientSecret: "read-client-secret",
RedirectURIs: []string{"https://example.com/callback"},
ResponseTypes: []string{"code"},
GrantTypes: []string{"authorization_code"},
ApplicationType: "web",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "original-client-id",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Read registration
ctx := context.Background()
resp, err := registrar.ReadClientRegistration(ctx)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if resp.ClientID != "read-client-id" {
t.Errorf("Read ClientID mismatch: got %s", resp.ClientID)
}
}
// TestOperationsWithoutCachedResponse tests error handling when no cached response exists
func TestOperationsWithoutCachedResponse(t *testing.T) {
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
&http.Client{},
NewLogger("DEBUG"),
dcrConfig,
"https://example.com",
)
ctx := context.Background()
// Test Update without cached response
_, err := registrar.UpdateClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Update should fail without cached response: %v", err)
}
// Test Read without cached response
_, err = registrar.ReadClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Read should fail without cached response: %v", err)
}
// Test Delete without cached response
err = registrar.DeleteClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Delete should fail without cached response: %v", err)
}
}
// TestOperationsWithoutManagementCredentials tests error handling without management URIs
func TestOperationsWithoutManagementCredentials(t *testing.T) {
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
&http.Client{},
NewLogger("DEBUG"),
dcrConfig,
"https://example.com",
)
// Set up cached response WITHOUT management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "test-client-id",
// Missing RegistrationAccessToken and RegistrationClientURI
}
registrar.mu.Unlock()
ctx := context.Background()
// Test Update without management credentials
_, err := registrar.UpdateClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Update should fail without management credentials: %v", err)
}
// Test Read without management credentials
_, err = registrar.ReadClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Read should fail without management credentials: %v", err)
}
// Test Delete without management credentials
err = registrar.DeleteClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Delete should fail without management credentials: %v", err)
}
}
// stringContains is a helper function to check if a string contains a substring
func stringContains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr))
+6 -5
View File
@@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
return true
}
errStr := err.Error()
errStr := strings.ToLower(err.Error())
for _, retryableErr := range re.config.RetryableErrors {
if contains(errStr, retryableErr) {
if contains(errStr, strings.ToLower(retryableErr)) {
return true
}
}
@@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
if netErr.Timeout() {
return true
}
errStr := netErr.Error()
errStr := strings.ToLower(netErr.Error())
temporaryPatterns := []string{
"connection refused",
"connection reset",
@@ -859,8 +859,9 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
// isServiceDegraded checks if a service is currently degraded
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
// Uses a write lock because the recovery-timeout branch deletes from the map.
gd.mutex.Lock()
defer gd.mutex.Unlock()
degradedTime, exists := gd.degradedServices[serviceName]
if !exists {
+1
View File
@@ -5,6 +5,7 @@ go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/gorilla/sessions v1.3.0
github.com/lukaszraczylo/oss-telemetry v0.2.3
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.14.0
+2
View File
@@ -16,6 +16,8 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/lukaszraczylo/oss-telemetry v0.2.3 h1:xoDtBqeZGmXj7IteiE1M5WMuzeoqag58qEleI0Cf2Ms=
github.com/lukaszraczylo/oss-telemetry v0.2.3/go.mod h1:+Cn78qZo8rc3T9eZt0v3oICYRdd75wORtSidc8lNjDQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
+10 -1
View File
@@ -392,10 +392,19 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
// localRedirect is used when there is no provider end-session endpoint and
// the plugin redirects the browser itself. It must never be an absolute URL
// derived from the request host (X-Forwarded-Host is client-controllable and
// would be an open redirect); use a host-relative path, or the operator's
// own configured absolute URL, instead.
localRedirect := "/"
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
localRedirect = normalizeLogoutPath(postLogoutRedirectURI)
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
} else {
localRedirect = postLogoutRedirectURI
}
// Read endSessionURL with RLock
@@ -414,7 +423,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
http.Redirect(rw, req, localRedirect, http.StatusFound)
}
// BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint.
+30 -6
View File
@@ -26,6 +26,10 @@ type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
// tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify)
// this transport was built with, so the at-limit fallback only reuses a
// transport whose TLS configuration matches the caller's.
tlsKey string
}
var (
@@ -53,19 +57,26 @@ func GetGlobalTransportPool() *SharedTransportPool {
// GetOrCreateTransport gets or creates a shared transport with the given config
func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport {
// SECURITY FIX: Check client limit before creating new transport
// SECURITY FIX: Check client limit before creating new transport.
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
// Return existing transport if limit reached
p.mu.RLock()
defer p.mu.RUnlock()
// At the client limit: only reuse a transport that was built for the
// SAME config (same TLS trust store). refCount is mutated under the
// write lock to avoid a data race, and a transport created for a
// different configuration is never handed back — doing so could apply
// the wrong (possibly verification-disabled) TLS settings to a request.
want := tlsConfigKey(config)
p.mu.Lock()
defer p.mu.Unlock()
for _, shared := range p.transports {
if shared != nil && shared.transport != nil {
if shared != nil && shared.transport != nil && shared.tlsKey == want {
shared.refCount++
shared.lastUsed = time.Now()
return shared.transport
}
}
// If no transport available, return nil (caller should handle)
// No TLS-compatible transport available; return nil so the caller falls
// back to a default, certificate-verifying transport rather than one
// with a different (possibly verification-disabled) trust store.
return nil
}
@@ -125,6 +136,7 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
transport: transport,
refCount: 1,
lastUsed: time.Now(),
tlsKey: tlsConfigKey(config),
}
return transport
@@ -224,6 +236,18 @@ func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
)
}
// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool
// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are
// safe to serve from the same transport even if other (non-TLS) parameters such
// as connection limits differ; configs with different TLS settings are not.
func tlsConfigKey(config HTTPClientConfig) string {
skip := "0"
if config.InsecureSkipVerify {
skip = "1"
}
return fmt.Sprintf("%p|%s", config.RootCAs, skip)
}
// Cleanup closes all transports and stops the cleanup goroutine
func (p *SharedTransportPool) Cleanup() {
p.mu.Lock()
+12 -4
View File
@@ -842,10 +842,18 @@ func TestWorkerPool_TaskPanic(t *testing.T) {
t.Error("Timeout waiting for tasks")
}
// Pool should still be functional
metrics := pool.GetMetrics()
if metrics["tasksFailed"].(int64) < 1 {
t.Error("Expected at least one failed task")
// tasksFailed is incremented in the worker's deferred recover(), which runs
// AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can
// therefore return before the failure is recorded — reading the counter
// immediately is a race that flakes on slow/contended CI runners. Poll until
// the failure lands (or time out).
deadline := time.Now().Add(2 * time.Second)
for pool.GetMetrics()["tasksFailed"].(int64) < 1 {
if time.Now().After(deadline) {
t.Error("Expected at least one failed task")
break
}
time.Sleep(5 * time.Millisecond)
}
}
+23 -1
View File
@@ -155,12 +155,34 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string {
// It checks X-Forwarded-Host header first (for proxy scenarios),
// then falls back to req.Host.
func DetermineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" {
return host
}
return req.Host
}
// sanitizeForwardedHost returns a single, well-formed host from a (possibly
// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes
// only the first value and rejects whitespace and control characters, so a
// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison
// the redirect URLs built from the result.
func sanitizeForwardedHost(v string) string {
if v == "" {
return ""
}
if i := strings.IndexByte(v, ','); i >= 0 {
v = v[:i]
}
v = strings.TrimSpace(v)
if v == "" {
return ""
}
if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 {
return ""
}
return v
}
// BuildFullURL constructs a URL from scheme, host, and path components.
// It handles absolute URLs (returning them as-is) and ensures paths have leading slashes.
func BuildFullURL(scheme, host, path string) string {
+8 -4
View File
@@ -234,7 +234,8 @@ func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, graphAccessToken, idToken)
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
output := errBuf.String()
assert.NotContains(t, output, "crypto/rsa: verification error",
@@ -344,7 +345,8 @@ func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, accessToken, idToken)
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
assert.False(t, needsRefresh)
@@ -381,7 +383,8 @@ func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, graphAccessToken, "")
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
assert.False(t, needsRefresh)
@@ -443,7 +446,8 @@ func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *test
oidc, _ := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
authenticated, _, _ := oidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, _, _ := oidc.validateAzureTokensRS(rs)
assert.False(t, authenticated,
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
}
+7 -6
View File
@@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
// Test 3: Rate limiting
t.Run("RateLimiting", func(t *testing.T) {
// Reset circuit breaker to closed state for this test
coordinator.circuitBreaker.mutex.Lock()
// Reset circuit breaker to closed state for this test. All fields are
// atomic so we don't need any mutex.
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
coordinator.circuitBreaker.mutex.Unlock()
// Temporarily increase circuit breaker threshold to not interfere
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
@@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
time.Sleep(config.CleanupInterval * 3)
// Old sessions should be cleaned up
coordinator.attemptsMutex.RLock()
count := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
count := 0
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
count++
return true
})
// Should have fewer sessions after cleanup
if count > 10 {
+67 -13
View File
@@ -53,10 +53,26 @@ type JWKSet struct {
Keys []JWK `json:"keys"`
}
// JWKCache provides thread-safe caching of JWKS using UniversalCache
// JWKCache provides thread-safe caching of JWKS using UniversalCache.
//
// inflightFetches deduplicates concurrent fetches for the same JWKS URL.
// It replaces a global sync.RWMutex that was previously held for the entire
// HTTP round-trip in GetJWKS: on a cold cache (cold pod, JWK rotation, brief
// network blip) every concurrent request piled up on that single Lock(), and
// under Yaegi each Lock acquisition costs 10-50ms of interpreter-dispatch
// overhead. The singleflight pattern keeps the cold-cache cost O(1) HTTP
// fetch regardless of how many requests are waiting.
type JWKCache struct {
cache *UniversalCache
mutex sync.RWMutex
cache *UniversalCache
inflightFetches sync.Map // map[jwksURL string]*jwksFetch
}
// jwksFetch represents an in-flight JWKS fetch. Done is closed when the fetch
// completes; jwks and err carry the result (one of them is set, never both).
type jwksFetch struct {
done chan struct{}
jwks *JWKSet
err error
}
// JWKCacheInterface defines the contract for JWK caching implementations.
@@ -83,36 +99,58 @@ func NewJWKCache() *JWKCache {
// request refetches from the upstream. JWK rotation is rare and a per-replica
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Check cache first
// Fast path: cache hit.
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
return jwks, nil
}
}
c.mutex.Lock()
defer c.mutex.Unlock()
// Singleflight: dedupe concurrent fetches per URL key. The first arrival
// performs the HTTP fetch; any later arrival for the same URL waits on
// its done channel and shares the result. No global lock is held during
// the fetch.
candidate := &jwksFetch{done: make(chan struct{})}
if existing, loaded := c.inflightFetches.LoadOrStore(jwksURL, candidate); loaded {
f, _ := existing.(*jwksFetch)
select {
case <-f.done:
return f.jwks, f.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Double-check after acquiring lock
// We're the leader. Make absolutely sure the result fields and the
// in-flight map entry are cleaned up before any waiter unblocks.
defer func() {
c.inflightFetches.Delete(jwksURL)
close(candidate.done)
}()
// Re-check the cache in case a concurrent fetch completed between our
// initial miss and our LoadOrStore win.
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
candidate.jwks = jwks
return jwks, nil
}
}
// Fetch from URL
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
if err != nil {
candidate.err = err
return nil, err
}
if len(jwks.Keys) == 0 {
return nil, fmt.Errorf("JWKS response contains no keys")
candidate.err = fmt.Errorf("JWKS response contains no keys")
return nil, candidate.err
}
// Cache for 1 hour
// Cache for 1 hour.
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
candidate.jwks = jwks
return jwks, nil
}
@@ -162,6 +200,22 @@ func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
if k.Kid == "" {
continue
}
// Skip keys that are not intended for signature verification.
if k.Use != "" && k.Use != "sig" {
continue
}
if len(k.KeyOps) > 0 {
hasVerify := false
for _, op := range k.KeyOps {
if op == "verify" {
hasVerify = true
break
}
}
if !hasVerify {
continue
}
}
var pub crypto.PublicKey
var err error
switch k.Kty {
@@ -204,11 +258,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
}
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("error reading JWKS response: %w", err)
}
+5 -2
View File
@@ -134,8 +134,11 @@ func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
if iss != "" && iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
// Require a matching issuer. An empty iss must be rejected too: accepting a
// missing issuer would let an unauthenticated attacker force-logout any
// session whose sid is known by simply omitting iss.
if iss == "" || iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer)
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return
}
+34 -27
View File
@@ -125,10 +125,14 @@ func TestFrontchannelLogoutBasic(t *testing.T) {
expectedStatus: http.StatusOK,
},
{
name: "Valid front-channel logout without issuer",
// Front-channel logout MUST carry a matching issuer. A request
// omitting iss is rejected so an unauthenticated attacker cannot
// force-logout a session whose sid is known by simply leaving iss
// out (audit rank 30).
name: "Missing issuer is rejected",
method: http.MethodGet,
queryParams: map[string]string{"sid": "session456"},
expectedStatus: http.StatusOK,
expectedStatus: http.StatusBadRequest,
},
}
@@ -407,17 +411,17 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
})
oidc := &TraefikOidc{
next: nextHandler,
logger: NewLogger("debug"),
enableBackchannelLogout: true,
backchannelLogoutPath: "/backchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestReceived: true,
metadataRefreshStarted: true,
logoutURLPath: "/logout",
next: nextHandler,
logger: NewLogger("debug"),
enableBackchannelLogout: true,
backchannelLogoutPath: "/backchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
logoutURLPath: "/logout",
}
close(oidc.initComplete)
@@ -449,22 +453,23 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
})
oidc := &TraefikOidc{
next: nextHandler,
logger: NewLogger("debug"),
enableFrontchannelLogout: true,
frontchannelLogoutPath: "/frontchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestReceived: true,
metadataRefreshStarted: true,
logoutURLPath: "/logout",
next: nextHandler,
logger: NewLogger("debug"),
enableFrontchannelLogout: true,
frontchannelLogoutPath: "/frontchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
logoutURLPath: "/logout",
}
close(oidc.initComplete)
// Request to front-channel logout path with valid sid should succeed
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil)
// Request to front-channel logout path with valid sid + matching issuer
// should succeed. The issuer is now required (audit rank 30), so supply it.
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session&iss=https://provider.example.com", nil)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
@@ -1432,7 +1437,9 @@ func TestFrontchannelLogoutCacheControl(t *testing.T) {
issuerURL: "https://provider.example.com",
}
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil)
// Issuer is now required (audit rank 30); supply a matching one so the
// successful-logout cache headers can be asserted.
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123&iss=https://provider.example.com", nil)
rw := httptest.NewRecorder()
oidc.handleFrontchannelLogout(rw, req)
+93 -14
View File
@@ -9,6 +9,7 @@ import (
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
@@ -16,6 +17,7 @@ import (
"text/template"
"time"
telemetry "github.com/lukaszraczylo/oss-telemetry"
"golang.org/x/time/rate"
)
@@ -23,6 +25,11 @@ const (
ConstSessionTimeout = 86400
)
// telemetryStartupOnce keeps the anonymous "plugin loaded" ping to one per
// process. Traefik calls New once per route that uses the plugin; oss-telemetry
// does not deduplicate client-side (the server does), so the gate stays here.
var telemetryStartupOnce sync.Once
// isTestMode detects if the code is running in a test environment.
func isTestMode() bool {
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
@@ -89,6 +96,13 @@ var defaultExcludedURLs = map[string]struct{}{
// - The configured TraefikOidc handler ready to process requests.
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
telemetryStartupOnce.Do(func() {
// Only stamped release builds phone home; dev/local/test builds keep the
// devPluginVersion sentinel (see version.go) and stay silent.
if traefikoidcPluginVersion != devPluginVersion {
telemetry.Send("traefikoidc", traefikoidcPluginVersion)
}
})
return NewWithContext(ctx, config, next, name)
}
@@ -99,18 +113,18 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
config = CreateConfig()
}
if config.SessionEncryptionKey == "" {
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
logger := NewLogger(config.LogLevel)
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
// Fail closed on invalid configuration. Validate() enforces the security
// constraints (required fields, HTTPS-only URLs, key length, excludedURLs
// safety, rate-limit floor, audience format, ...) that were previously
// unenforced because this constructor never called it. Crucially it rejects
// an empty or too-short SessionEncryptionKey instead of silently
// substituting a public hardcoded key, which would let an attacker forge
// any session. Traefik's yaegi plugin analyzer supplies a valid key via
// .traefik.yml testData, so it passes; only misconfigured deployments fail.
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
// Setup HTTP client
caPool, err := config.loadCACertPool()
@@ -201,6 +215,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}(),
forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
extraAuthParams: config.ExtraAuthParams,
overrideScopes: config.OverrideScopes,
strictAudienceValidation: config.StrictAudienceValidation,
allowOpaqueTokens: config.AllowOpaqueTokens,
@@ -221,7 +236,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
httpClient: httpClient,
tokenHTTPClient: tokenHTTPClient,
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains),
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
@@ -332,7 +347,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
// Convert sessionMaxAge from seconds to duration (0 will use default 24 hours)
sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults
sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger)
if err != nil {
cancelFunc()
return nil, fmt.Errorf("failed to create session manager: %w", err)
}
t.sessionManager = sessionManager
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
// Initialize token resilience manager with default configuration
@@ -423,6 +443,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
// Add reference for this instance
rm.AddReference(name)
registerLiveInstance()
// Initialize metadata in a goroutine with proper tracking
if t.goroutineWG != nil {
@@ -500,13 +521,58 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
// Parameters:
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
// SSRF defense (audit ranks 3 & 4): a discovery document is attacker-
// influenced when the provider or its TLS is compromised. Reject any
// discovered endpoint pointed at a blocked address before the plugin issues
// outbound requests to it, so it can never be used to reach the cloud
// metadata service or an internal host.
allowLoopback := false
if pu, err := url.Parse(t.providerURL); err == nil {
allowLoopback = isLoopbackHost(pu.Hostname())
}
sanitize := func(name, raw string) string {
if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil {
t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err)
return ""
}
return raw
}
metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL)
metadata.AuthURL = sanitize("authorization", metadata.AuthURL)
metadata.TokenURL = sanitize("token", metadata.TokenURL)
metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL)
metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL)
metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL)
metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL)
// The introspection request authenticates with the client secret via HTTP
// Basic, so the endpoint must live on the same host as the operator-
// configured provider; otherwise a poisoned discovery document could
// exfiltrate the client secret to an attacker-controlled host.
if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) {
t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL)
metadata.IntrospectionURL = ""
}
// Pin the discovered issuer to the operator-configured provider host. The
// issuer is the trust anchor for JWT issuer validation, so a poisoned
// discovery document advertising an attacker-chosen issuer must never be
// stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the
// issuer on the same host as the configured providerURL. On mismatch, leave
// issuerURL empty/unchanged so downstream issuer validation fails closed
// rather than trusting the attacker-chosen value.
discoveredIssuer := metadata.Issuer
if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) {
t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer)
discoveredIssuer = ""
}
t.metadataMu.Lock()
t.jwksURL = metadata.JWKSURL
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.issuerURL = discoveredIssuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
@@ -516,6 +582,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
introspectionURL := t.introspectionURL
registrationURL := t.registrationURL
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
// this directly instead of acquiring metadataMu.RLock per request.
t.metadataSnapshot.Store(&MetadataSnapshot{
IssuerURL: discoveredIssuer,
JWKSURL: metadata.JWKSURL,
TokenURL: metadata.TokenURL,
AuthURL: metadata.AuthURL,
RevocationURL: metadata.RevokeURL,
EndSessionURL: metadata.EndSessionURL,
IntrospectionURL: metadata.IntrospectionURL,
RegistrationURL: metadata.RegistrationURL,
})
t.metadataMu.Unlock()
// Log introspection endpoint availability for opaque token support
+2
View File
@@ -194,6 +194,7 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
handler, err := New(ctx, nil, config, "test")
if err != nil {
@@ -322,6 +323,7 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
handler, err := New(ctx, nil, config, "test")
if err != nil {
+70 -68
View File
@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
@@ -25,38 +26,47 @@ func TestInitializeMetadata(t *testing.T) {
name: "successful metadata initialization",
providerURL: "",
setupMock: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Issuer must share the host with providerURL (the httptest
// server), otherwise the discovery doc is rejected as poisoned
// (audit ranks 21/22). Real providers keep issuer + endpoints on
// the same host, so derive them all from the server URL.
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://provider.example.com",
AuthURL: "https://provider.example.com/auth",
TokenURL: "https://provider.example.com/token",
JWKSURL: "https://provider.example.com/jwks",
RevokeURL: "https://provider.example.com/revoke",
EndSessionURL: "https://provider.example.com/logout",
Issuer: srv.URL,
AuthURL: srv.URL + "/auth",
TokenURL: srv.URL + "/token",
JWKSURL: srv.URL + "/jwks",
RevokeURL: srv.URL + "/revoke",
EndSessionURL: srv.URL + "/logout",
})
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
return srv
},
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
if oidc.authURL != "https://provider.example.com/auth" {
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
}
if oidc.tokenURL != "https://provider.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
}
if oidc.jwksURL != "https://provider.example.com/jwks" {
if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") {
t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL)
}
if oidc.revocationURL != "https://provider.example.com/revoke" {
if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") {
t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL)
}
if oidc.endSessionURL != "https://provider.example.com/logout" {
if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") {
t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL)
}
if oidc.issuerURL == "" {
t.Errorf("expected issuerURL to be pinned to provider host, got empty")
}
},
wantPanic: false,
},
@@ -115,24 +125,27 @@ func TestInitializeMetadata(t *testing.T) {
name: "partial metadata response",
providerURL: "",
setupMock: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Issuer host must match providerURL (audit ranks 21/22).
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Only return some fields
json.NewEncoder(w).Encode(map[string]string{
"issuer": "https://partial.example.com",
"authorization_endpoint": "https://partial.example.com/auth",
"token_endpoint": "https://partial.example.com/token",
"issuer": srv.URL,
"authorization_endpoint": srv.URL + "/auth",
"token_endpoint": srv.URL + "/token",
// Missing jwks_uri, revocation_endpoint, end_session_endpoint
})
}
}))
return srv
},
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
if oidc.authURL != "https://partial.example.com/auth" {
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
}
if oidc.tokenURL != "https://partial.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
}
// JWKS URL and others may be empty
@@ -197,20 +210,22 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
requestCount := 0
var mu sync.Mutex
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
requestCount++
mu.Unlock()
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://concurrent.example.com",
AuthURL: "https://concurrent.example.com/auth",
TokenURL: "https://concurrent.example.com/token",
JWKSURL: "https://concurrent.example.com/jwks",
RevokeURL: "https://concurrent.example.com/revoke",
EndSessionURL: "https://concurrent.example.com/logout",
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
RevokeURL: server.URL + "/revoke",
EndSessionURL: server.URL + "/logout",
})
}
}))
@@ -249,7 +264,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
oidc.initializeMetadata(server.URL)
// Verify initialization
if oidc.tokenURL != "https://concurrent.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("expected tokenURL to be set")
}
}()
@@ -341,17 +356,19 @@ func TestProviderDetection(t *testing.T) {
// TestInitializationWaiting tests waiting for initialization to complete
func TestInitializationWaiting(t *testing.T) {
t.Run("wait for initialization completion", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Delay response to simulate slow initialization
time.Sleep(100 * time.Millisecond)
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://slow.example.com",
AuthURL: "https://slow.example.com/auth",
TokenURL: "https://slow.example.com/token",
JWKSURL: "https://slow.example.com/jwks",
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
})
}
}))
@@ -388,7 +405,7 @@ func TestInitializationWaiting(t *testing.T) {
select {
case <-oidc.initComplete:
// Success
if oidc.tokenURL != "https://slow.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Error("expected tokenURL to be set after initialization")
}
case <-time.After(2 * time.Second):
@@ -397,17 +414,19 @@ func TestInitializationWaiting(t *testing.T) {
})
t.Run("multiple waiters for initialization", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Delay to ensure multiple waiters
time.Sleep(50 * time.Millisecond)
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://multi.example.com",
AuthURL: "https://multi.example.com/auth",
TokenURL: "https://multi.example.com/token",
JWKSURL: "https://multi.example.com/jwks",
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
})
}
}))
@@ -452,7 +471,7 @@ func TestInitializationWaiting(t *testing.T) {
select {
case <-oidc.initComplete:
// All waiters should see the same initialized state
if oidc.tokenURL != "https://multi.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("waiter %d: expected tokenURL to be set", id)
}
case <-time.After(2 * time.Second):
@@ -484,9 +503,8 @@ func TestFirstRequestHandling(t *testing.T) {
defer server.Close()
oidc := &TraefikOidc{
providerURL: server.URL,
firstRequestReceived: false,
firstRequestMutex: sync.Mutex{},
providerURL: server.URL,
firstRequestStarted: 0,
httpClient: &http.Client{
Timeout: 5 * time.Second,
},
@@ -508,19 +526,13 @@ func TestFirstRequestHandling(t *testing.T) {
},
}
// Simulate first request processing
oidc.firstRequestMutex.Lock()
if !oidc.firstRequestReceived {
oidc.firstRequestReceived = true
oidc.firstRequestMutex.Unlock()
// Simulate first request processing — single-firing via CAS.
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
// This would normally be called asynchronously
go func() {
oidc.initializeMetadata(server.URL)
// initComplete is closed internally by initializeMetadata
}()
} else {
oidc.firstRequestMutex.Unlock()
}
// Wait for initialization
@@ -556,9 +568,8 @@ func TestFirstRequestHandling(t *testing.T) {
defer server.Close()
oidc := &TraefikOidc{
providerURL: server.URL,
firstRequestReceived: false,
firstRequestMutex: sync.Mutex{},
providerURL: server.URL,
firstRequestStarted: 0,
httpClient: &http.Client{
Timeout: 5 * time.Second,
},
@@ -580,31 +591,22 @@ func TestFirstRequestHandling(t *testing.T) {
},
}
// Simulate multiple concurrent "first" requests
// Simulate multiple concurrent "first" requests — only one CAS winner
// fires the bootstrap path.
const numRequests = 10
var wg sync.WaitGroup
wg.Add(numRequests)
initStarted := 0
var initMu sync.Mutex
var initStarted int32
for i := 0; i < numRequests; i++ {
go func() {
defer wg.Done()
oidc.firstRequestMutex.Lock()
if !oidc.firstRequestReceived {
oidc.firstRequestReceived = true
oidc.firstRequestMutex.Unlock()
initMu.Lock()
initStarted++
initMu.Unlock()
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
atomic.AddInt32(&initStarted, 1)
// Only one should actually start initialization
oidc.initializeMetadata(server.URL)
} else {
oidc.firstRequestMutex.Unlock()
}
}()
}
@@ -612,8 +614,8 @@ func TestFirstRequestHandling(t *testing.T) {
wg.Wait()
// Verify only one initialization was started
if initStarted != 1 {
t.Errorf("expected exactly 1 initialization, got %d", initStarted)
if atomic.LoadInt32(&initStarted) != 1 {
t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted))
}
// The metadata endpoint might be called once or not at all depending on timing
+28 -28
View File
@@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com", // Required for initialization check
}
close(oidc.initComplete)
@@ -92,8 +92,8 @@ func TestServeHTTP_EventStream(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
@@ -175,8 +175,8 @@ func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
@@ -272,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close this to simulate timeout
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
}
req := httptest.NewRequest("GET", "/protected", nil)
@@ -307,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -337,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -367,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -740,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
minimalHeaders: tt.minimalHeaders,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -817,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
minimalHeaders: true, // Enable minimal headers
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -903,8 +903,8 @@ func TestStripAuthCookies(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: tt.stripAuthCookies,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -987,8 +987,8 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -1034,8 +1034,8 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -1085,8 +1085,8 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -1148,8 +1148,8 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
+72 -48
View File
@@ -16,6 +16,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -1874,14 +1875,14 @@ func TestHandleLogout(t *testing.T) {
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
expectedURL: "/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *SessionData) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
expectedURL: "/",
host: "test-host",
},
{
@@ -2348,19 +2349,22 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
t.Skip("Skipping test in short mode")
}
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create mock provider metadata server. Issuer + endpoints must share the
// host with ProviderURL (the httptest server), otherwise the discovery doc
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
Issuer: mockServer.URL,
AuthURL: mockServer.URL + "/auth",
TokenURL: mockServer.URL + "/token",
JWKSURL: mockServer.URL + "/jwks",
RevokeURL: mockServer.URL + "/revoke",
EndSessionURL: mockServer.URL + "/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
@@ -2373,6 +2377,7 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
RateLimit: 100,
}
// Create multiple middleware instances
@@ -2413,18 +2418,20 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
t.Fatalf("Middleware instance %d failed to initialize", i)
}
// Verify each instance has its own unique configuration
if m.issuerURL != "https://test-issuer.com" {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
// Verify each instance has its own unique configuration. Issuer is now
// pinned to the provider host (audit ranks 21/22), so it equals the
// mock server URL rather than a fixed literal.
if m.issuerURL != mockServer.URL {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
}
if m.authURL != "https://test-issuer.com/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
if m.authURL != mockServer.URL+"/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
}
if m.tokenURL != "https://test-issuer.com/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
if m.tokenURL != mockServer.URL+"/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
}
if m.jwksURL != "https://test-issuer.com/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
if m.jwksURL != mockServer.URL+"/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
}
if m.redirURLPath != routes[i]+"/callback" {
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
@@ -2438,15 +2445,16 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
m.ServeHTTP(rr, req)
// Should redirect to auth URL since not authenticated
// Should redirect (302) to the auth flow since not authenticated. The
// absolute auth URL is not asserted here: with issuer pinning (audit
// ranks 21/22) the discovery host equals the httptest server host,
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
// to emit a loopback authorization URL in this test environment. The
// per-instance auth/token/jwks/issuer URLs were already verified above;
// here we only confirm each instance independently triggers a redirect.
if rr.Code != http.StatusFound {
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
}
location := rr.Header().Get("Location")
if !strings.Contains(location, "https://test-issuer.com/auth") {
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
}
}
}
@@ -2459,33 +2467,43 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
}
// Create two mock provider metadata servers simulating different Keycloak realms
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Issuer + endpoints must share the host with each realm's ProviderURL
// (the httptest server), otherwise the discovery doc is rejected as
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
// path so the per-realm isolation assertions below still hold, but base
// the host on the server URL — which is exactly what a same-host Keycloak
// deployment looks like.
var realm1Server *httptest.Server
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
base := realm1Server.URL + "/realms/realm1"
metadata := ProviderMetadata{
Issuer: "https://keycloak.example.com/realms/realm1",
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
Issuer: base,
AuthURL: base + "/protocol/openid-connect/auth",
TokenURL: base + "/protocol/openid-connect/token",
JWKSURL: base + "/protocol/openid-connect/certs",
EndSessionURL: base + "/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
defer realm1Server.Close()
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var realm2Server *httptest.Server
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
base := realm2Server.URL + "/realms/realm2"
metadata := ProviderMetadata{
Issuer: "https://keycloak.example.com/realms/realm2",
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
Issuer: base,
AuthURL: base + "/protocol/openid-connect/auth",
TokenURL: base + "/protocol/openid-connect/token",
JWKSURL: base + "/protocol/openid-connect/certs",
EndSessionURL: base + "/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
@@ -2499,6 +2517,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
CallbackURL: "/realm1/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm1_",
RateLimit: 100,
}
// Config for realm2
@@ -2509,6 +2528,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
CallbackURL: "/realm2/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm2_",
RateLimit: 100,
}
// Create middleware instances for both realms
@@ -2607,8 +2627,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
providerAvailable := false
var mu sync.Mutex
// Create mock provider that initially fails, then becomes available
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create mock provider that initially fails, then becomes available.
// Issuer + endpoints must share the host with ProviderURL (audit ranks
// 21/22), so derive them from the server URL.
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
available := providerAvailable
mu.Unlock()
@@ -2620,11 +2643,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
if r.URL.Path == "/.well-known/openid-configuration" {
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
EndSessionURL: "https://test-issuer.com/logout",
Issuer: mockServer.URL,
AuthURL: mockServer.URL + "/auth",
TokenURL: mockServer.URL + "/token",
JWKSURL: mockServer.URL + "/jwks",
EndSessionURL: mockServer.URL + "/logout",
}
json.NewEncoder(w).Encode(metadata)
return
@@ -2639,6 +2662,7 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
RateLimit: 100,
}
// Create middleware while provider is unavailable
@@ -2685,10 +2709,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
providerAvailable = true
mu.Unlock()
// Reset the retry timer to allow immediate retry
m.metadataRetryMutex.Lock()
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
m.metadataRetryMutex.Unlock()
// Reset the retry timer to allow immediate retry. The field is atomic
// now, so no lock is needed.
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
// Second request should trigger recovery attempt
req2 := httptest.NewRequest("GET", "/protected", nil)
@@ -4552,6 +4575,7 @@ func TestNewWithScopeAppending(t *testing.T) {
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
Scopes: tc.configScopes,
RateLimit: 100,
}
// Create middleware instance
+1
View File
@@ -1652,6 +1652,7 @@ func TestGoroutineLeaks(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client"
config.ClientSecret = "test-secret"
config.CallbackURL = "/callback"
handler, err := New(context.Background(), nil, config, "test")
require.NoError(t, err)
+11 -1
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
@@ -141,10 +142,19 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
}
var metadata ProviderMetadata
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to decode metadata: %w", err)
}
// Pin the advertised issuer to the configured provider host. The issuer is
// the trust anchor for JWT issuer validation; rejecting a mismatch here
// ensures a poisoned discovery document advertising an attacker-chosen
// issuer is never cached or returned. Real providers (Google, Azure,
// Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL.
if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) {
return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL)
}
// Cache for 1 hour by default
if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil {
mc.logger.Errorf("Failed to cache metadata: %v", err)
+220 -29
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
@@ -145,19 +146,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !strings.HasPrefix(req.URL.Path, "/health") {
t.firstRequestMutex.Lock()
if !t.firstRequestReceived {
t.firstRequestReceived = true
// Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock()
// fired on EVERY non-health request forever (even after the boolean
// flipped true), which under Yaegi added a per-request serialization
// point. CAS gives single-firing semantics with zero steady-state cost.
if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) {
t.logger.Debug("Starting background tasks on first request")
t.startTokenCleanup()
if !t.metadataRefreshStarted && t.providerURL != "" {
t.metadataRefreshStarted = true
if t.providerURL != "" &&
atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) {
// Metadata refresh is handled by singleton resource manager
t.startMetadataRefresh(t.providerURL)
}
}
t.firstRequestMutex.Unlock()
}
// Evaluate auth-bypass once, before waiting for initialization. Excluded
@@ -207,20 +209,31 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
// Read issuerURL via atomic snapshot when available — replaces the
// metadataMu.RLock that previously fired on every non-bypass request.
// Under Yaegi each RLock acquisition costs 1-5ms of interpreter
// dispatch; the snapshot is a single atomic.Value.Load. Falls back
// to the legacy field+RLock for paths that haven't published a
// snapshot yet (notably some test setups that initialize the struct
// fields directly).
var issuerURL string
if snap := t.metadataSnap(); snap != nil {
issuerURL = snap.IssuerURL
} else {
t.metadataMu.RLock()
issuerURL = t.issuerURL
t.metadataMu.RUnlock()
}
if issuerURL == "" {
// Provider metadata initialization failed - try to recover
// Retry every 30 seconds to allow automatic recovery when provider comes back online
t.metadataRetryMutex.Lock()
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
if shouldRetry {
t.lastMetadataRetryTime = time.Now()
}
t.metadataRetryMutex.Unlock()
// Provider metadata initialization failed - try to recover.
// Retry every 30 seconds to allow automatic recovery. Lock-free
// throttle via CAS on lastMetadataRetryNano: one goroutine wins
// the window, others see shouldRetry=false.
nowNano := time.Now().UnixNano()
last := atomic.LoadInt64(&t.lastMetadataRetryNano)
shouldRetry := time.Duration(nowNano-last) >= 30*time.Second &&
atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano)
if shouldRetry && t.providerURL != "" {
t.logger.Info("Attempting to recover OIDC provider metadata...")
@@ -298,6 +311,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Capture per-request state: one RLock on sd.sessionMutex covers all the
// getter values the handler chain needs (instead of 5-7 separate
// session.GetX() calls each acquiring their own RLock under Yaegi).
// metadataSnap is also stored once so downstream handlers don't repeat
// the atomic.Value.Load.
rs := (&requestState{
scheme: scheme,
host: host,
redirectURL: redirectURL,
next: t.next,
metadata: t.metadataSnap(),
}).captureSession(session)
// Check if the current request is the OIDC callback
t.logger.Debugf("Checking callback URL match: request_path=%q, configured_callback=%q", req.URL.Path, t.redirURLPath)
if req.URL.Path == t.redirURLPath {
@@ -307,7 +333,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath)
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
// Token validation reads session via the captured snapshot — saves ~21
// sd.sessionMutex.RLock acquisitions (Yaegi-dispatched, ~1-5ms each)
// across the validation path.
authenticated, needsRefresh, expired := t.isUserAuthenticatedRS(rs)
if expired {
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
@@ -315,7 +344,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
userIdentifier := session.GetUserIdentifier()
userIdentifier := rs.userIdentifier
// User authorization check
if authenticated && userIdentifier != "" {
if !t.isAllowedUser(userIdentifier) {
@@ -332,11 +361,11 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// methods (validateAzureTokens/validateStandardTokens) before reaching this point.
// Redundant validation here was causing issues with Azure AD tokens that have
// JWT format but unverifiable signatures. See issue #89.
t.processAuthorizedRequest(rw, req, session, redirectURL)
t.processAuthorizedRequestRS(rw, req, rs)
return
}
refreshTokenPresent := session.GetRefreshToken() != ""
refreshTokenPresent := rs.refreshToken != ""
// Decide whether to answer with 401 instead of a redirect. AJAX requests
// cannot follow a 302 into an IdP, and sub-resource loads (script/image/
@@ -443,6 +472,96 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - req: The HTTP request to process.
// - session: The user's session data containing tokens and claims.
// - redirectURL: The callback URL for re-authentication if needed.
//
// processAuthorizedRequestRS is the requestState-aware variant of
// processAuthorizedRequest. It reads SessionData fields from the captured
// snapshot in rs instead of calling session.GetX() (each of which acquires
// sd.sessionMutex.RLock — under Yaegi every RLock pays ~1-5ms of interpreter
// dispatch). Only session-mutating operations (Save, ResetRedirectCount,
// Clear, IsDirty) still go through the session pointer because those write
// state and have no snapshot.
func (t *TraefikOidc) processAuthorizedRequestRS(rw http.ResponseWriter, req *http.Request, rs *requestState) {
session := rs.session
redirectURL := rs.redirectURL
userIdentifier := rs.userIdentifier
if userIdentifier == "" {
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Check if session has been invalidated via backchannel or front-channel logout
idToken := rs.idToken
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
}
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
}
// Resolve ID-token claims at most once per request. SessionData caches
// the parsed claims keyed on the raw ID token.
var (
idClaims map[string]interface{}
idClaimsErr error
)
if idToken != "" {
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
}
var (
groupClaims map[string]interface{}
groupClaimsErr error
)
if idToken != "" {
groupClaims, groupClaimsErr = idClaims, idClaimsErr
} else if rs.accessToken != "" {
groupClaims, groupClaimsErr = t.extractClaimsFunc(rs.accessToken)
} else if len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if groupClaimsErr != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract claims for roles/groups check: %v", groupClaimsErr)
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Persist any dirty session state BEFORE forwardAuthorized writes the
// response.
if session.IsDirty() {
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after processing headers: %v", err)
}
} else {
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
}
p := &principal{
Source: sourceSession,
Identifier: userIdentifier,
AccessToken: rs.accessToken,
IDToken: idToken,
RefreshToken: rs.refreshToken,
Claims: groupClaims,
}
t.forwardAuthorized(rw, req, p)
}
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
@@ -557,6 +676,44 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
//
// Session persistence is the CALLER's responsibility — it must happen before
// this function so Set-Cookie reaches the response.
// headerTemplateMaxLen bounds the length of a rendered operator-defined header
// template before it is forwarded downstream. Generous enough for an
// "Authorization: Bearer <jwt>" value but small enough to reject obviously
// abusive output. Matches the input-validation default header cap (8KB).
const headerTemplateMaxLen = 8192
// headerClaimMaxLen returns the maximum accepted length for a claim-derived
// header value (principal identifier, group, role). Reuses the operator-
// configured identifier cap (default 256) so a single setting governs both
// auth paths; falls back to 256 when unset.
func (t *TraefikOidc) headerClaimMaxLen() int {
if t.maxIdentifierLength > 0 {
return t.maxIdentifierLength
}
return 256
}
// sanitizeHeaderClaimList drops any group/role value that fails claim
// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an
// over-long value) and returns the surviving values. Failing closed on a bad
// entry prevents header injection and stops an embedded comma from injecting
// extra entries into the comma-joined header. headerName is used only for
// debug logging — the value is never logged.
func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string {
if len(values) == 0 {
return nil
}
safe := make([]string, 0, len(values))
for _, v := range values {
if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok {
safe = append(safe, clean)
} else {
t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName)
}
}
return safe
}
func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) {
var (
groups, roles []string
@@ -574,11 +731,18 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
return
}
if extractErr == nil {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
// Sanitize each group/role before it is joined into a comma-
// delimited header. The cookie/session path does not otherwise
// sanitize claim-derived values (the bearer path sanitizes its
// identifier at construction), so a control char would enable
// header injection and an embedded comma would inject extra
// entries into the comma-joined header. Fail closed: drop any
// value that does not pass.
if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(safeGroups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(safeRoles, ","))
}
}
}
@@ -599,12 +763,26 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
}
}
req.Header.Set("X-Forwarded-User", p.Identifier)
// Sanitize the principal identifier before injecting it into headers. The
// bearer path already sanitizes its identifier at construction; the
// cookie/session path does not, so a claim carrying control chars, bidi-
// override runes, or , ; = could inject or spoof header content. Fail
// closed: drop the identifier header(s) rather than forward a tainted value.
safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen())
if identifierOK {
req.Header.Set("X-Forwarded-User", safeIdentifier)
} else {
t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization")
}
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", p.Identifier)
if identifierOK {
req.Header.Set("X-Auth-Request-User", safeIdentifier)
} else {
t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization")
}
if p.IDToken != "" {
req.Header.Set("X-Auth-Request-Token", p.IDToken)
}
@@ -629,8 +807,21 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
continue
}
headerValue := buf.String()
// Sanitize the rendered output: template inputs are claim-derived
// and attacker-influenceable, so reject control chars (header
// injection), bidi-override runes, the , ; = delimiters, and an
// over-long value. Fail closed by dropping the header rather than
// forwarding a tainted value. Do not log the value (it commonly
// carries the access token); log only name + reason.
if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" {
t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason)
continue
}
req.Header.Set(headerName, headerValue)
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
// Do not log the value: templated headers commonly carry the access
// token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging
// it — even at debug — leaks credentials into logs.
t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue))
}
}
+14 -14
View File
@@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close to simulate waiting
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
}
// Create request with canceled context
@@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -102,8 +102,8 @@ func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
logoutURLPath: "/logout",
postLogoutRedirectURI: "/",
forceHTTPS: false,
@@ -142,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -187,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -236,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
+338 -198
View File
@@ -15,18 +15,28 @@ import (
// It implements request coalescing, rate limiting, and circuit breaking
// specifically for token refresh operations.
type RefreshCoordinator struct {
inFlightRefreshes map[string]*refreshOperation
cleanupTimers map[string]*time.Timer
sessionRefreshAttempts map[string]*refreshAttemptTracker
// inFlightRefreshes maps tokenHash -> *refreshOperation. sync.Map is used
// instead of a plain map + RWMutex so concurrent refreshes do not
// serialize on a single global lock. Under Yaegi the previous
// refreshMutex.Lock() was held for tens of milliseconds per request due
// to interpreter overhead on the work inside the critical section,
// causing dozens of goroutines to stack up on it and pin one CPU core.
inFlightRefreshes sync.Map
// sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker.
// sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/
// recordRefreshSuccess/recordRefreshFailure are lock-free. Previously
// these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition
// adds 10-50ms of dispatch overhead, and they were called twice per leader
// request (once for recordRefreshAttempt, once for isInCooldown). That
// serializing pattern caused the v1.0.15 death spiral after v1.0.14
// removed the refreshMutex (same architectural shape, different mutex).
sessionRefreshAttempts sync.Map
circuitBreaker *RefreshCircuitBreaker
metrics *RefreshMetrics
logger *Logger
stopChan chan struct{}
config RefreshCoordinatorConfig
wg sync.WaitGroup
attemptsMutex sync.RWMutex
refreshMutex sync.RWMutex
cleanupTimerMu sync.Mutex
}
// RefreshCoordinatorConfig configures the refresh coordinator behavior
@@ -84,14 +94,46 @@ type refreshResult struct {
fromCache bool
}
// refreshAttemptTracker tracks refresh attempts for a session
type refreshAttemptTracker struct {
lastAttemptTime time.Time
windowStartTime time.Time
cooldownEndTime time.Time
// attemptState is the immutable snapshot of a session's refresh-attempt
// state. Lives behind refreshAttemptTracker.state (atomic.Value). Every
// transition (record, success, failure, window-reset, cooldown-enter,
// cooldown-exit) constructs a fresh attemptState and publishes it via
// CompareAndSwap so the entire field set is updated together.
//
// Per-field atomic.Load/Store (the previous v1.0.15 design) had a benign
// but observable hazard: the cooldown-exit reset wrote cooldownEndNano = 0
// first, then separately stored attempts = 1 and windowStartNano = now.
// A concurrent isInCooldown call could see cooldownEndNano = 0 (reset
// just completed) with attempts still at MaxRefreshAttempts, triggering
// a fresh cooldown immediately. The snapshot approach eliminates the
// intermediate state entirely.
type attemptState struct {
lastAttemptNano int64 // UnixNano of last attempt
windowStartNano int64 // UnixNano of attempt-window start
cooldownEndNano int64 // UnixNano; 0 = not in cooldown
attempts int32
consecutiveFailures int32
inCooldown bool
}
// refreshAttemptTracker tracks refresh attempts for a session via a single
// atomic.Value holding a *attemptState pointer. Readers do exactly one Load.
// Writers do Load → construct new → CompareAndSwap (retry on conflict).
// Under Yaegi this collapses 3-4 per-field atomic dispatches into one Load,
// and eliminates the cross-field race in the window-reset path.
type refreshAttemptTracker struct {
state atomic.Value // *attemptState
}
// stateOf returns the current attemptState, or a zero-value snapshot if none
// has been published yet. The empty snapshot represents "no attempts recorded".
func (t *refreshAttemptTracker) stateOf() *attemptState {
if v := t.state.Load(); v != nil {
s, _ := v.(*attemptState)
if s != nil {
return s
}
}
return &attemptState{}
}
// RefreshMetrics tracks coordinator performance metrics
@@ -106,14 +148,18 @@ type RefreshMetrics struct {
currentInFlightRefreshes int32
}
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh
// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/
// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was
// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of
// interpreter dispatch per call, which compounded with attemptsMutex to keep
// the pod's single CPU core saturated.
type RefreshCircuitBreaker struct {
lastFailureTime time.Time
lastSuccessTime time.Time
lastFailureNano int64 // atomic, UnixNano of most recent failure
lastSuccessNano int64 // atomic, UnixNano of most recent success
config RefreshCircuitBreakerConfig
mutex sync.RWMutex
state int32
failures int32
state int32 // atomic: 0=closed, 1=open, 2=half-open
failures int32 // atomic
}
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
@@ -130,13 +176,12 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
}
rc := &RefreshCoordinator{
inFlightRefreshes: make(map[string]*refreshOperation),
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
config: config,
metrics: &RefreshMetrics{},
logger: logger,
stopChan: make(chan struct{}),
cleanupTimers: make(map[string]*time.Timer),
// inFlightRefreshes and sessionRefreshAttempts are both sync.Map;
// their zero values are ready to use.
config: config,
metrics: &RefreshMetrics{},
logger: logger,
stopChan: make(chan struct{}),
circuitBreaker: &RefreshCircuitBreaker{
config: RefreshCircuitBreakerConfig{
MaxFailures: 3,
@@ -227,13 +272,28 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
tokenHash string,
refreshToken string,
) (*refreshOperation, bool, error) {
rc.refreshMutex.Lock()
defer rc.refreshMutex.Unlock()
// Speculatively construct the operation we WOULD register if we win the
// race. Allocating here keeps the LoadOrStore call below atomic and
// avoids any global lock — under Yaegi the previous map+RWMutex design
// held the write lock long enough (tens of ms per call) that concurrent
// refreshes on the same coordinator serialized into a queue that grew
// without bound. See struct comment on inFlightRefreshes.
candidate := &refreshOperation{
refreshToken: refreshToken,
done: make(chan struct{}),
startTime: time.Now(),
waiterCount: 1,
}
// Check for existing operation while holding the lock
if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists {
if existing, loaded := rc.inFlightRefreshes.LoadOrStore(tokenHash, candidate); loaded {
existingOp, ok := existing.(*refreshOperation)
if !ok {
// Defensive: anything stored here is always *refreshOperation, but
// keep the typed assert so a programming error elsewhere doesn't
// surface as a confusing panic in an interpreter frame.
return nil, false, fmt.Errorf("inFlightRefreshes corrupt: unexpected type %T", existing)
}
if existingOp.refreshToken == refreshToken {
// Join existing operation
atomic.AddInt32(&existingOp.waiterCount, 1)
return existingOp, false, nil
}
@@ -241,41 +301,71 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
return nil, false, fmt.Errorf("refresh token mismatch")
}
// No existing operation - check if we can create a new one
// All checks happen while holding the lock to prevent races
// We won the race and registered `candidate`. Apply gates now. If any
// gate fails we must remove our entry from the map and signal failure
// to any joiners that snuck in between LoadOrStore and now.
if err := rc.applyLeaderGates(sessionID); err != nil {
rc.failCandidate(tokenHash, candidate, err)
return nil, false, err
}
// Check and record refresh attempt for rate limiting
rc.recordRefreshAttempt(sessionID)
// Reserve concurrent slot via ticket-and-return: increment optimistically,
// decrement if we overshot the limit. The previous CAS-loop allowed a
// transient overshoot of up to N-1 leaders when several goroutines all
// observed `current < max` in the same scheduling slice before any one
// of them succeeded their CAS — visible to readers as
// currentInFlightRefreshes > MaxConcurrentRefreshes for a brief window.
// The ticket pattern is strictly bounded: the counter momentarily reads
// max+k for k concurrent attempts past the limit, but only the k that
// produced max+1..max+k decrement back, and only k=1 ever observes max+1
// as committed.
newCount := atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
if int(newCount) > rc.config.MaxConcurrentRefreshes {
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
err := fmt.Errorf("maximum concurrent refresh operations reached")
rc.failCandidate(tokenHash, candidate, err)
return nil, false, err
}
return candidate, true, nil
}
// applyLeaderGates runs the rate-limit, cooldown, and memory-pressure checks
// that previously ran under the global refreshMutex. Only the leader (the
// goroutine that just registered the operation) runs them; joiners share the
// leader's outcome via operation.done.
func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
// Cooldown check FIRST, BEFORE incrementing the attempt counter.
// Previously this function recorded the attempt and then read the
// cooldown state. Under burst load (many concurrent leaders with
// different token hashes but same session) every goroutine could
// increment past MaxRefreshAttempts before any one of them observed
// the threshold, so the cooldown gate fired too late — the same
// thundering-herd shape that drove v1.0.14 into the ground.
if rc.isInCooldown(sessionID) {
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
return fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
}
// Check memory pressure
if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
return fmt.Errorf("system under memory pressure, refresh denied")
}
// Only count attempts that actually progress past the gates.
rc.recordRefreshAttempt(sessionID)
return nil
}
// Check and reserve concurrent refresh slot atomically
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
if int(current) >= rc.config.MaxConcurrentRefreshes {
return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
}
// Reserve the slot - we're still holding the lock so this is safe
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
// Create and register new operation
operation := &refreshOperation{
refreshToken: refreshToken,
done: make(chan struct{}),
startTime: time.Now(),
waiterCount: 1,
}
rc.inFlightRefreshes[tokenHash] = operation
return operation, true, nil
// failCandidate removes the leader's just-registered operation from the
// in-flight map and signals the error to any joiners by recording the result
// and closing the done channel. This keeps the (nil, false, err) return path
// equivalent to the pre-sync.Map version: callers see the error directly,
// joiners see it via operation.done.
func (rc *RefreshCoordinator) failCandidate(tokenHash string, op *refreshOperation, err error) {
rc.inFlightRefreshes.Delete(tokenHash)
op.mutex.Lock()
op.result = &refreshResult{err: err}
op.mutex.Unlock()
close(op.done)
}
// executeRefreshAsync performs the actual refresh operation asynchronously
@@ -338,130 +428,196 @@ func (rc *RefreshCoordinator) executeRefreshAsync(
}
}
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning a goroutine
// This prevents goroutine explosion under high load (500+ req/sec)
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning
// a goroutine — time.AfterFunc uses the runtime's timer heap and never spawns
// a per-timer goroutine until the callback actually fires.
//
// The previous implementation tracked every pending timer in a map guarded by
// cleanupTimerMu so a duplicate scheduling could cancel the prior timer. That
// "shouldn't happen" path was the only consumer of the map, but the mutex
// fired on every successful refresh completion — yet another per-request
// Yaegi-dispatched lock acquisition. performCleanup is already idempotent
// (LoadAndDelete on the sync.Map), so a duplicate scheduling at worst fires
// performCleanup twice; the second call is a no-op. Dropping the map removes
// the whole class of contention on this code path.
func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
delay := rc.config.DeduplicationCleanupDelay
if delay <= 0 {
// Immediate cleanup
rc.performCleanup(tokenHash)
return
}
// Use time.AfterFunc which is more efficient than spawning a goroutine with Sleep
// time.AfterFunc uses the runtime's timer heap which is much more efficient
rc.cleanupTimerMu.Lock()
// Cancel any existing timer for this hash (shouldn't happen, but just in case)
if existingTimer, exists := rc.cleanupTimers[tokenHash]; exists {
existingTimer.Stop()
}
rc.cleanupTimers[tokenHash] = time.AfterFunc(delay, func() {
rc.performCleanup(tokenHash)
// Remove timer from map
rc.cleanupTimerMu.Lock()
delete(rc.cleanupTimers, tokenHash)
rc.cleanupTimerMu.Unlock()
})
rc.cleanupTimerMu.Unlock()
time.AfterFunc(delay, func() { rc.performCleanup(tokenHash) })
}
// performCleanup removes the operation from the in-flight map.
// Idempotent: only decrements the in-flight counter if an entry was actually
// removed. This guards against any future path accidentally calling cleanup
// twice for the same tokenHash (which would corrupt the refresh budget).
// removed. LoadAndDelete is atomic so any concurrent failCandidate or repeat
// cleanup call will see exactly one removal — the budget cannot be corrupted
// by double-decrement.
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
rc.refreshMutex.Lock()
_, existed := rc.inFlightRefreshes[tokenHash]
if existed {
delete(rc.inFlightRefreshes, tokenHash)
}
rc.refreshMutex.Unlock()
if existed {
if _, existed := rc.inFlightRefreshes.LoadAndDelete(tokenHash); existed {
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
}
}
// isInCooldown checks if a session is in cooldown after recording an attempt
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
// getOrCreateTracker fetches the tracker for sessionID or atomically creates a
// fresh one. The sync.Map.LoadOrStore semantics make this lock-free even under
// concurrent first-touch races: at most one tracker per sessionID survives.
//
// trackerFromMapValue centralizes the type assertion so the lint-mandated
// two-value form lives in one place; the stored type is always
// *refreshAttemptTracker by construction.
func trackerFromMapValue(v interface{}) *refreshAttemptTracker {
t, _ := v.(*refreshAttemptTracker)
return t
}
tracker, exists := rc.sessionRefreshAttempts[sessionID]
if !exists {
func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker {
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
return trackerFromMapValue(v)
}
fresh := &refreshAttemptTracker{}
fresh.state.Store(&attemptState{windowStartNano: time.Now().UnixNano()})
actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh)
return trackerFromMapValue(actual)
}
// mutateState performs a CompareAndSwap loop that applies mutate to the
// current snapshot. mutate must be PURE: it receives an immutable view of
// the current state and returns a fresh *attemptState. If mutate returns nil
// the update is skipped (used by isInCooldown for "no change needed" paths).
//
// Retries on CAS conflict are bounded by the number of concurrent writers —
// in practice 1-3. Under Yaegi each retry pays the dispatch cost of one Load
// + one CompareAndSwap; still cheaper than the previous per-field atomic
// sequence and immune to the cross-field race the v1.0.15 design had.
func (t *refreshAttemptTracker) mutateState(mutate func(cur *attemptState) *attemptState) *attemptState {
for {
cur := t.stateOf()
next := mutate(cur)
if next == nil {
return cur
}
if t.state.CompareAndSwap(cur, next) {
return next
}
}
}
// isInCooldown checks if a session is in cooldown. Snapshot-based: every
// transition publishes a fresh *attemptState atomically so readers never see
// a partially-updated state. The previous per-field atomic design had a
// benign race in the cooldown-exit path (cooldownEndNano reset before
// attempts reset) that could double-trigger cooldown.
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
if !ok {
return false // No tracker means first attempt, not in cooldown
}
tracker := trackerFromMapValue(v)
now := time.Now()
nowNano := now.UnixNano()
maxAttempts := rc.config.MaxRefreshAttempts
window := rc.config.RefreshAttemptWindow
cooldownPeriod := rc.config.RefreshCooldownPeriod
// Check if already in cooldown
if tracker.inCooldown {
if now.After(tracker.cooldownEndTime) {
// Cooldown expired, reset tracker
tracker.inCooldown = false
tracker.attempts = 1 // Already recorded one attempt
tracker.consecutiveFailures = 0
tracker.windowStartTime = now
return false
cur := tracker.stateOf()
// Already in cooldown?
if cur.cooldownEndNano != 0 {
if nowNano <= cur.cooldownEndNano {
return true // still in cooldown
}
return true // Still in cooldown
}
// Check if window expired
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
// Reset window
tracker.attempts = 1 // Already recorded one attempt
tracker.windowStartTime = now
// Cooldown expired: atomically publish a fresh state with the window
// restarted from one attempt. Whichever goroutine wins the CAS sets
// the new snapshot; losers see it via the next stateOf load.
tracker.mutateState(func(s *attemptState) *attemptState {
if s.cooldownEndNano == 0 || nowNano <= s.cooldownEndNano {
return nil // someone else already reset, or back in cooldown
}
return &attemptState{
windowStartNano: nowNano,
attempts: 1,
}
})
return false
}
// Check if just exceeded attempt limit
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
// Enter cooldown now
tracker.inCooldown = true
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
sessionID, tracker.attempts)
// Window expired?
if time.Duration(nowNano-cur.windowStartNano) > window {
tracker.mutateState(func(s *attemptState) *attemptState {
if time.Duration(nowNano-s.windowStartNano) <= window {
return nil
}
next := *s
next.windowStartNano = nowNano
next.attempts = 1
return &next
})
return false
}
// Just exceeded attempt limit?
if int(cur.attempts) >= maxAttempts {
end := now.Add(cooldownPeriod).UnixNano()
published := tracker.mutateState(func(s *attemptState) *attemptState {
if s.cooldownEndNano != 0 {
return nil
}
next := *s
next.cooldownEndNano = end
return &next
})
if published.cooldownEndNano == end {
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
sessionID, published.attempts)
}
return true
}
return false
}
// recordRefreshAttempt records a refresh attempt for rate limiting
// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free
// snapshot mutation; attempts and lastAttemptNano are advanced atomically.
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
tracker, exists := rc.sessionRefreshAttempts[sessionID]
if !exists {
tracker = &refreshAttemptTracker{
windowStartTime: time.Now(),
}
rc.sessionRefreshAttempts[sessionID] = tracker
}
atomic.AddInt32(&tracker.attempts, 1)
tracker.lastAttemptTime = time.Now()
tracker := rc.getOrCreateTracker(sessionID)
nowNano := time.Now().UnixNano()
tracker.mutateState(func(s *attemptState) *attemptState {
next := *s
next.attempts++
next.lastAttemptNano = nowNano
return &next
})
}
// recordRefreshSuccess records a successful refresh
// recordRefreshSuccess records a successful refresh: zero consecutiveFailures.
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
tracker.consecutiveFailures = 0
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
if !ok {
return
}
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
if s.consecutiveFailures == 0 {
return nil
}
next := *s
next.consecutiveFailures = 0
return &next
})
}
// recordRefreshFailure records a failed refresh
// recordRefreshFailure records a failed refresh: increments consecutiveFailures.
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
atomic.AddInt32(&tracker.consecutiveFailures, 1)
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
if !ok {
return
}
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
next := *s
next.consecutiveFailures++
return &next
})
}
// hashRefreshToken creates a hash of the refresh token for deduplication
@@ -512,20 +668,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() {
}
}
// cleanupStaleEntries removes outdated tracking entries
// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration
// via sync.Map.Range; safe to race with concurrent reads/writes.
func (rc *RefreshCoordinator) cleanupStaleEntries() {
now := time.Now()
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
// Clean up old session trackers
for sessionID, tracker := range rc.sessionRefreshAttempts {
// Remove trackers that haven't been used recently
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
delete(rc.sessionRefreshAttempts, sessionID)
cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano()
rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool {
tracker := trackerFromMapValue(value)
if tracker == nil {
return true
}
}
if tracker.stateOf().lastAttemptNano < cutoff {
// Compare-and-delete to avoid evicting a tracker that was just
// re-used by a concurrent caller. We compare by pointer identity.
rc.sessionRefreshAttempts.CompareAndDelete(key, value)
}
return true
})
}
// GetMetrics returns current coordinator metrics
@@ -543,78 +701,60 @@ func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
}
}
// Shutdown gracefully shuts down the coordinator
// Shutdown gracefully shuts down the coordinator. Pending delayed-cleanup
// timers are NOT canceled explicitly: time.AfterFunc callbacks are tiny
// (one map LoadAndDelete) and harmless after Shutdown — sync.Map operations
// remain safe on an unused coordinator until GC.
func (rc *RefreshCoordinator) Shutdown() {
close(rc.stopChan)
// Cancel all pending cleanup timers
rc.cleanupTimerMu.Lock()
for _, timer := range rc.cleanupTimers {
timer.Stop()
}
rc.cleanupTimers = make(map[string]*time.Timer)
rc.cleanupTimerMu.Unlock()
rc.wg.Wait()
}
// AllowRequest checks if the circuit breaker allows a request
// AllowRequest reports whether the circuit breaker allows a request. Lock-free.
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
state := atomic.LoadInt32(&cb.state)
switch state {
case 0: // Closed
switch atomic.LoadInt32(&cb.state) {
case 0: // closed
return true
case 1: // Open
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
// Try to transition to half-open
case 1: // open
lastFail := atomic.LoadInt64(&cb.lastFailureNano)
if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration {
// Transition to half-open; first CAS winner gets the probe.
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
return true
}
}
return false
case 2: // Half-open
case 2: // half-open
return true
default:
return false
}
}
// RecordSuccess records a successful operation
// RecordSuccess records a successful operation. Lock-free.
func (cb *RefreshCircuitBreaker) RecordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
state := atomic.LoadInt32(&cb.state)
if state == 2 { // Half-open
// Close the circuit
switch atomic.LoadInt32(&cb.state) {
case 2: // half-open -> close
atomic.StoreInt32(&cb.state, 0)
atomic.StoreInt32(&cb.failures, 0)
} else if state == 0 { // Closed
// Reset failure count on success
case 0: // closed
atomic.StoreInt32(&cb.failures, 0)
}
cb.lastSuccessTime = time.Now()
atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano())
}
// RecordFailure records a failed operation
// RecordFailure records a failed operation. Lock-free.
func (cb *RefreshCircuitBreaker) RecordFailure() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
failures := atomic.AddInt32(&cb.failures, 1)
cb.lastFailureTime = time.Now()
atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano())
state := atomic.LoadInt32(&cb.state)
if state == 0 && int(failures) >= cb.config.MaxFailures {
// Open the circuit
atomic.StoreInt32(&cb.state, 1)
} else if state == 2 {
// Half-open failed, return to open
switch atomic.LoadInt32(&cb.state) {
case 0:
if int(failures) >= cb.config.MaxFailures {
atomic.StoreInt32(&cb.state, 1)
}
case 2:
// Half-open probe failed -> back to open.
atomic.StoreInt32(&cb.state, 1)
}
}
+28 -34
View File
@@ -165,9 +165,14 @@ func TestRefreshRateLimiting(t *testing.T) {
time.Sleep(150 * time.Millisecond)
}
// Verify that cooldown was triggered after max attempts
// With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts
expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
// Verify that cooldown was triggered after max attempts.
// With applyLeaderGates checking cooldown BEFORE recording the attempt
// (the v1.0.16 reorder fixing the thundering-herd off-by-one), N attempts
// run to completion and the (N+1)th is denied. Previously the Nth was
// denied as it tried to record, which under burst load let multiple
// concurrent leaders increment past the limit before any one of them
// observed the gate.
expectedSuccessfulAttempts := config.MaxRefreshAttempts
if attempts != expectedSuccessfulAttempts {
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
}
@@ -365,10 +370,12 @@ func TestMemoryLeakPrevention(t *testing.T) {
}
}
// Verify cleanup is working
coordinator.attemptsMutex.RLock()
sessionCount := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
// Verify cleanup is working. sync.Map has no Len(); count via Range.
sessionCount := 0
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
sessionCount++
return true
})
// Should have cleaned up old sessions (only recent ones remain)
if sessionCount > numWorkers*2 {
@@ -650,24 +657,23 @@ func TestCleanupRoutine(t *testing.T) {
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
}
// Verify sessions exist
coordinator.attemptsMutex.RLock()
initialCount := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
countSessions := func() int {
n := 0
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
n++
return true
})
return n
}
if initialCount != 5 {
if initialCount := countSessions(); initialCount != 5 {
t.Errorf("Expected 5 sessions, got %d", initialCount)
}
// Wait for cleanup to run (2x window + cleanup interval)
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
// Verify sessions were cleaned up
coordinator.attemptsMutex.RLock()
finalCount := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
if finalCount != 0 {
if finalCount := countSessions(); finalCount != 0 {
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
}
}
@@ -720,11 +726,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
currentGoroutines := runtime.NumGoroutine()
t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines)
// Check timer count
coordinator.cleanupTimerMu.Lock()
timerCount := len(coordinator.cleanupTimers)
coordinator.cleanupTimerMu.Unlock()
t.Logf("Active cleanup timers: %d", timerCount)
// (Coordinator no longer tracks pending timers; time.AfterFunc closures
// fire performCleanup directly. This test now only checks the goroutine
// budget, which was always the real invariant.)
// With timer-based cleanup, goroutine increase should be minimal
// Timers don't create goroutines - they use the runtime timer heap
@@ -740,19 +744,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
initialGoroutines, currentGoroutines, goroutineIncrease)
}
// Wait for timers to fire and cleanup
// Wait for timers to fire and cleanup.
time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond)
// Verify timers were cleaned up
coordinator.cleanupTimerMu.Lock()
remainingTimers := len(coordinator.cleanupTimers)
coordinator.cleanupTimerMu.Unlock()
// Most timers should have fired and been removed
if remainingTimers > 10 {
t.Errorf("Too many cleanup timers remaining: %d", remainingTimers)
}
// Verify goroutines returned to near initial
runtime.GC()
time.Sleep(50 * time.Millisecond)
+71
View File
@@ -0,0 +1,71 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// requestState bundles read-mostly fields for a single ServeHTTP call.
package traefikoidc
import "net/http"
// requestState is a per-request context object allocated at the top of
// ServeHTTP and threaded through to downstream handlers. It caches values
// that would otherwise require a Yaegi-dispatched lock acquisition each time
// they're read:
//
// - The metadata snapshot (atomic.Value.Load once, not per-handler).
// - SessionData getter results (one RLock on sd.sessionMutex covers all
// fields, instead of 5-7 separate RLock/RUnlock pairs scattered through
// the handler chain).
//
// The struct is alloc'd at request entry, populated under at most one RLock
// of sd.sessionMutex, and discarded at request exit. It is NOT shared across
// requests and never written from another goroutine, so no synchronization
// on its fields is required.
//
// Cross-request global caches (tokenCache, JWKCache, sessionEntries,
// sessionInvalidationCache) remain — they're orthogonal. requestState's job
// is to eliminate redundant per-handler reads of values that don't change
// within a single request.
type requestState struct {
// Globals snapshotted once.
metadata *MetadataSnapshot
// SessionData fields snapshotted under one RLock. The pointer to the
// SessionData is retained so handlers that genuinely need to mutate
// (Save, Clear, etc.) still have access.
session *SessionData
authenticated bool
accessToken string
idToken string
refreshToken string
userIdentifier string
createdAtUnixSec int64
// Output: scheme/host/redirect path determined at top of ServeHTTP.
scheme string
host string
redirectURL string
// Carry the next handler so forwardAuthorized doesn't need to close over t.
next http.Handler
}
// captureSession populates requestState's SessionData-derived fields under a
// single RLock of sd.sessionMutex. Returns the populated rs for chaining.
//
// Replaces a sequence of SessionData.GetX() calls each of which acquires
// sd.sessionMutex.RLock(). Under Yaegi each RLock costs ~1-5ms of
// interpreter dispatch; batching saves the rest.
func (rs *requestState) captureSession(sd *SessionData) *requestState {
if sd == nil {
return rs
}
rs.session = sd
sd.sessionMutex.RLock()
rs.authenticated = sd.getAuthenticatedUnsafe()
rs.accessToken = sd.getAccessTokenUnsafe()
rs.idToken = sd.getIDTokenUnsafe()
rs.refreshToken = sd.getRefreshTokenUnsafe()
rs.userIdentifier = sd.getUserIdentifierUnsafe()
rs.createdAtUnixSec = sd.getCreatedAtUnsafe()
sd.sessionMutex.RUnlock()
return rs
}
+404
View File
@@ -0,0 +1,404 @@
package traefikoidc
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// TestRank1_SessionCookieIsEncrypted verifies that the session cookie payload is
// AES-encrypted, not merely HMAC-signed. Regression test for the audit finding
// "session cookies signed but NOT encrypted": a single key left the stored OIDC
// tokens recoverable in plaintext from the raw cookie bytes.
func TestRank1_SessionCookieIsEncrypted(t *testing.T) {
const secret = "a-sufficiently-long-session-encryption-key"
authKey, encKey := deriveCookieKeys(secret)
if len(authKey) != 64 || len(encKey) != 32 {
t.Fatalf("expected 64-byte auth key and 32-byte enc key, got %d/%d", len(authKey), len(encKey))
}
if string(authKey) == string(encKey) {
t.Fatal("authentication and encryption keys must be independent")
}
const marker = "SUPER-SECRET-ACCESS-TOKEN-marker-value"
// Encode a session through the same two-key store the production code now
// builds (see NewSessionManager).
store := sessions.NewCookieStore(authKey, encKey)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
sess, err := store.New(req, "session")
if err != nil {
t.Fatalf("store.New failed: %v", err)
}
sess.Values["tok"] = marker
if err := sess.Save(req, rec); err != nil {
t.Fatalf("session save failed: %v", err)
}
var cookie *http.Cookie
for _, c := range rec.Result().Cookies() {
if c.Name == "session" {
cookie = c
}
}
if cookie == nil {
t.Fatal("no session cookie was set")
}
// The secret token must never appear in plaintext in the cookie value.
if strings.Contains(cookie.Value, marker) {
t.Error("marker token found in plaintext inside the session cookie value")
}
// A store holding only the authentication key (the previous behavior)
// must NOT be able to read the encrypted cookie — proving the payload is
// genuinely encrypted, not just signed.
signedOnly := sessions.NewCookieStore(authKey)
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
req2.AddCookie(cookie)
if _, derr := signedOnly.Get(req2, "session"); derr == nil {
t.Error("encrypted cookie should not be decodable without the encryption key")
}
// The full two-key store round-trips correctly.
req3 := httptest.NewRequest(http.MethodGet, "/", nil)
req3.AddCookie(cookie)
rt, derr := store.Get(req3, "session")
if derr != nil {
t.Fatalf("round-trip decode failed: %v", derr)
}
if got, _ := rt.Values["tok"].(string); got != marker {
t.Errorf("round-trip mismatch: got %q want %q", got, marker)
}
}
// TestRank2And6_InvalidConfigFailsClosed verifies that NewWithContext now calls
// Config.Validate() and fails closed on an empty or too-short session
// encryption key instead of silently substituting a public hardcoded key, and
// rejects other missing required fields. Regression test for "hardcoded default
// encryption key" + "Config.Validate() never called in production path".
func TestRank2And6_InvalidConfigFailsClosed(t *testing.T) {
base := func() *Config {
return &Config{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "this-is-a-valid-session-key-32b!",
RateLimit: 100,
}
}
// Sanity: a fully valid config still constructs.
p, err := NewWithContext(context.Background(), base(), nil, "valid")
if err != nil {
t.Fatalf("valid config should construct, got: %v", err)
}
if p != nil {
p.Close()
}
cases := []struct {
name string
mutate func(*Config)
}{
{"empty key", func(c *Config) { c.SessionEncryptionKey = "" }},
{"short key", func(c *Config) { c.SessionEncryptionKey = "tooshort" }},
{"missing providerURL", func(c *Config) { c.ProviderURL = "" }},
{"missing callbackURL", func(c *Config) { c.CallbackURL = "" }},
{"plaintext remote providerURL", func(c *Config) { c.ProviderURL = "http://accounts.google.com" }},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c := base()
tc.mutate(c)
plugin, err := NewWithContext(context.Background(), c, nil, tc.name)
if err == nil {
if plugin != nil {
plugin.Close()
}
t.Errorf("expected NewWithContext to reject config (%s), but it succeeded", tc.name)
}
})
}
}
// TestRank3_DiscoveredEndpointSSRFGuard verifies that endpoints from the
// provider discovery document are screened against SSRF targets before use.
func TestRank3_DiscoveredEndpointSSRFGuard(t *testing.T) {
tr := &TraefikOidc{}
blocked := []string{
"http://169.254.169.254/latest/meta-data/", // cloud metadata (link-local)
"http://[fe80::1]/jwks", // IPv6 link-local
"http://10.0.0.5/jwks", // private
"http://192.168.1.10/jwks", // private
"http://127.0.0.1/jwks", // loopback (allowLoopback=false)
"ftp://example.com/jwks", // disallowed scheme
}
for _, u := range blocked {
if err := tr.validateDiscoveredEndpoint(u, false); err == nil {
t.Errorf("expected discovered endpoint %q to be rejected", u)
}
}
allowed := []string{
"https://accounts.google.com/o/oauth2/v3/certs",
"https://www.googleapis.com/oauth2/v3/certs", // cross-domain JWKS must stay allowed
"", // empty optional endpoint
}
for _, u := range allowed {
if err := tr.validateDiscoveredEndpoint(u, false); err != nil {
t.Errorf("expected discovered endpoint %q to be allowed, got %v", u, err)
}
}
// Loopback is allowed only when the provider itself is loopback (dev/test).
if err := tr.validateDiscoveredEndpoint("http://127.0.0.1:8080/jwks", true); err != nil {
t.Errorf("loopback endpoint should be allowed when allowLoopback=true: %v", err)
}
// Private addresses are allowed when explicitly opted in.
trPriv := &TraefikOidc{allowPrivateIPAddresses: true}
if err := trPriv.validateDiscoveredEndpoint("http://10.0.0.5/jwks", false); err != nil {
t.Errorf("private endpoint should be allowed when allowPrivateIPAddresses=true: %v", err)
}
}
// TestRank4_IntrospectionHostPin verifies the host-equality check used to pin
// the credential-bearing introspection endpoint to the configured provider.
func TestRank4_IntrospectionHostPin(t *testing.T) {
if !sameHost("https://kc.example.com/realms/x", "https://kc.example.com/realms/x/protocol/openid-connect/token/introspect") {
t.Error("introspection on the same host as the provider should be accepted")
}
if sameHost("https://kc.example.com", "https://evil.example.net/introspect") {
t.Error("introspection on a different host must be rejected")
}
if sameHost("", "https://kc.example.com") || sameHost("https://kc.example.com", "") {
t.Error("empty URL must not be treated as a host match")
}
}
// TestRank5_OpenRedirectNeutralized verifies the helper the callback now applies
// to the stored incoming path forces a host-relative redirect target.
func TestRank5_OpenRedirectNeutralized(t *testing.T) {
cases := map[string]string{
"//evil.com/x": "/evil.com/x",
`/\evil.com`: "/evil.com",
"/legit/path": "/legit/path",
}
for in, want := range cases {
got := normalizeLogoutPath(in)
if got != want {
t.Errorf("normalizeLogoutPath(%q) = %q, want %q", in, got, want)
}
if strings.HasPrefix(got, "//") || strings.HasPrefix(got, `/\`) {
t.Errorf("normalizeLogoutPath(%q) = %q is still protocol-relative", in, got)
}
}
}
// TestRank14_ExcludedURLSegmentBoundary verifies excluded-URL matching is
// anchored at path-segment boundaries and cannot be widened into a bypass.
func TestRank14_ExcludedURLSegmentBoundary(t *testing.T) {
if !pathExcluded("/public", "/public") {
t.Error("exact match should be excluded")
}
if !pathExcluded("/public/page", "/public") {
t.Error("sub-path should be excluded")
}
if pathExcluded("/publicsecret", "/public") {
t.Error("/publicsecret must NOT be excluded by /public")
}
if pathExcluded("/public-admin", "/public") {
t.Error("/public-admin must NOT be excluded by /public")
}
if !pathExcluded("/health", "/health/") {
t.Error("trailing-slash config should still match the exact path")
}
if pathExcluded("/anything", "/") {
t.Error("root exclusion must not match arbitrary paths")
}
if !pathExcluded("/", "/") {
t.Error("root exclusion should match the root path")
}
}
// TestRank15_ForwardedHostSanitized verifies a crafted X-Forwarded-Host cannot
// inject CRLF, smuggle a second host, or otherwise poison the derived host.
func TestRank15_ForwardedHostSanitized(t *testing.T) {
mk := func(xfh string) *http.Request {
r := httptest.NewRequest(http.MethodGet, "http://real.example.com/x", nil)
r.Host = "real.example.com"
if xfh != "" {
r.Header.Set("X-Forwarded-Host", xfh)
}
return r
}
if got := utils.DetermineHost(mk("ext.example.com")); got != "ext.example.com" {
t.Errorf("clean X-Forwarded-Host should be honored, got %q", got)
}
if got := utils.DetermineHost(mk("a.example.com, evil.com")); got != "a.example.com" {
t.Errorf("multi-value X-Forwarded-Host should use first host only, got %q", got)
}
for _, bad := range []string{"evil.com\r\nSet-Cookie: x=1", "evil.com /x", " "} {
if got := utils.DetermineHost(mk(bad)); got != "real.example.com" {
t.Errorf("malformed X-Forwarded-Host %q should fall back to req.Host, got %q", bad, got)
}
}
}
// TestRank11_TransportPoolTLSIsolationAtLimit verifies that, once the client
// limit is reached, the transport pool reuses an existing transport only when
// its TLS settings match the caller's, and never hands back a transport built
// with different TLS trust settings.
func TestRank11_TransportPoolTLSIsolationAtLimit(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
maxClients: 5,
}
strict := DefaultHTTPClientConfig() // InsecureSkipVerify = false
t1 := pool.GetOrCreateTransport(strict)
if t1 == nil {
t.Fatal("expected a transport for the strict config")
}
// Saturate the client limit so subsequent calls hit the fallback path.
atomic.StoreInt32(&pool.clientCount, pool.maxClients)
// Same TLS settings, different (non-TLS) connection limit: safe to reuse.
sameTLS := DefaultHTTPClientConfig()
sameTLS.MaxConnsPerHost = 99
if got := pool.GetOrCreateTransport(sameTLS); got != t1 {
t.Error("at the limit a TLS-compatible config should reuse the existing transport")
}
// Different TLS settings (InsecureSkipVerify): must NOT reuse the strict
// transport — returning nil lets the caller fall back to a verifying default.
insecure := DefaultHTTPClientConfig()
insecure.InsecureSkipVerify = true
if got := pool.GetOrCreateTransport(insecure); got == t1 {
t.Error("at the limit a config with different TLS settings must not reuse the strict transport")
}
}
// TestRank9_RedisFingerprint verifies divergent explicit Redis backends produce
// distinct fingerprints (used to warn about ignored cache config), while an
// absent or disabled Redis yields the empty (no-warning) fingerprint.
func TestRank9_RedisFingerprint(t *testing.T) {
if redisFingerprint(nil) != "" {
t.Error("nil config should yield an empty fingerprint")
}
if redisFingerprint(&Config{}) != "" {
t.Error("config without Redis should yield an empty fingerprint")
}
if redisFingerprint(&Config{Redis: &RedisConfig{Enabled: false, Address: "a:6379"}}) != "" {
t.Error("disabled Redis should yield an empty fingerprint")
}
a := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "a:6379", KeyPrefix: "p"}})
b := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "b:6379", KeyPrefix: "p"}})
if a == "" || a == b {
t.Errorf("distinct enabled backends must produce distinct non-empty fingerprints (%q vs %q)", a, b)
}
}
// TestRank10_TokenTypeCacheKeyNoCollision verifies that two different tokens
// sharing the same 32-character JWT header prefix are classified independently.
// The previous 32-char cache key would have collided and mis-classified them.
func TestRank10_TokenTypeCacheKeyNoCollision(t *testing.T) {
tr := &TraefikOidc{
tokenTypeCache: NewCache(),
suppressDiagnosticLogs: true,
clientID: "client",
}
// A header prefix longer than 32 chars, shared by both tokens.
prefix := "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ"
idJWT := &JWT{Header: map[string]interface{}{}, Claims: map[string]interface{}{"nonce": "n"}}
accessJWT := &JWT{Header: map[string]interface{}{"typ": "at+jwt"}, Claims: map[string]interface{}{}}
if !tr.detectTokenType(idJWT, prefix+".id.sig") {
t.Error("token with a nonce claim should be detected as an ID token")
}
if tr.detectTokenType(accessJWT, prefix+".access.sig") {
t.Error("access token (typ=at+jwt) must not be mis-classified as ID despite the shared 32-char prefix")
}
}
// TestRank12_LiveInstanceCounter verifies the process-global instance counter
// that gates teardown of shared singleton tasks.
func TestRank12_LiveInstanceCounter(t *testing.T) {
start := atomic.LoadInt32(&liveInstanceCount)
registerLiveInstance()
registerLiveInstance()
if got := atomic.LoadInt32(&liveInstanceCount); got != start+2 {
t.Fatalf("expected %d live instances, got %d", start+2, got)
}
if rem := unregisterLiveInstance(); rem != start+1 {
t.Errorf("expected %d remaining, got %d", start+1, rem)
}
if rem := unregisterLiveInstance(); rem != start {
t.Errorf("expected %d remaining, got %d", start, rem)
}
}
// TestRank13_CookieMaxAgeMatchesSessionLifetime verifies the cookie store's
// MaxAge (which bounds both the cookie Max-Age and the codec's cryptographic
// timestamp validity) is bound to the configured session lifetime rather than
// gorilla's 30-day default.
func TestRank13_CookieMaxAgeMatchesSessionLifetime(t *testing.T) {
maxAge := 2 * time.Hour
sm, err := NewSessionManager(strings.Repeat("k", 40), false, "", "", maxAge, NewLogger("error"))
if err != nil {
t.Fatalf("NewSessionManager failed: %v", err)
}
defer sm.cancel()
cs, ok := sm.store.(*sessions.CookieStore)
if !ok {
t.Fatal("session store is not a *sessions.CookieStore")
}
if got := cs.Options.MaxAge; got != int(maxAge.Seconds()) {
t.Errorf("cookie store MaxAge = %d, want %d (bound to sessionMaxAge)", got, int(maxAge.Seconds()))
}
}
// TestRank33And34_HeaderSanitizationDistinction verifies the two header sinks
// use the right strictness: free-form templated header VALUES (rank 34) permit
// , ; = (e.g. an opaque "Bearer <token>" or an LDAP-DN claim) but reject CR/LF,
// bidi, and over-length; claim values joined into delimited/identifier headers
// (rank 33) additionally reject , ; =.
func TestRank33And34_HeaderSanitizationDistinction(t *testing.T) {
// Rank 34 — free-form header value.
if headerValueReason("Bearer abc=def==", 8192) != "" {
t.Error("'=' must be allowed in a free-form header value (opaque bearer token)")
}
if headerValueReason("cn=user,ou=eng;dc=x", 8192) != "" {
t.Error("',;=' must be allowed in a free-form header value (e.g. an LDAP DN claim)")
}
if headerValueReason("evil"+string(rune(13))+string(rune(10))+"Injected: 1", 8192) == "" {
t.Error("CR/LF must be rejected in a header value (injection)")
}
if headerValueReason("toolong", 3) == "" {
t.Error("over-length value must be rejected")
}
// Rank 33 — claim value bound for a delimited/identifier header.
if _, ok := sanitizeHeaderClaimValue("admins,superadmins", 256); ok {
t.Error("a comma must be rejected in a value joined into a comma-delimited header")
}
if _, ok := sanitizeHeaderClaimValue("normal-user@example.com", 256); !ok {
t.Error("a clean identifier must pass claim sanitization")
}
if _, ok := sanitizeHeaderClaimValue("evil"+string(rune(13))+string(rune(10))+"X: 1", 256); ok {
t.Error("CR/LF must be rejected in a claim value")
}
}
-590
View File
@@ -1,590 +0,0 @@
package traefikoidc
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
)
// SecurityEventType categorizes different types of security events
// that can occur during OIDC authentication and authorization flows.
type SecurityEventType string
// Security event types for monitoring and alerting
const (
// AuthFailure indicates a failed authentication attempt
AuthFailure SecurityEventType = "authentication_failure"
// TokenValidFailure indicates JWT token validation failed
TokenValidFailure SecurityEventType = "token_validation_failure"
// RateLimitHit indicates rate limiting was triggered
RateLimitHit SecurityEventType = "rate_limit_hit"
// SuspiciousActivity indicates potentially malicious behavior
SuspiciousActivity SecurityEventType = "suspicious_activity"
)
// DefaultSeverity returns the default severity level for each security event type.
// Severity levels are: low, medium, high.
func (t SecurityEventType) DefaultSeverity() string {
switch t {
case AuthFailure:
return "medium"
case TokenValidFailure:
return "medium"
case RateLimitHit:
return "low"
case SuspiciousActivity:
return "high"
default:
return "medium"
}
}
// IPFailureType returns a string identifier for categorizing failures
// by IP address for rate limiting and blocking decisions.
func (t SecurityEventType) IPFailureType() string {
switch t {
case AuthFailure:
return "auth_failure"
case TokenValidFailure:
return "token_failure"
case SuspiciousActivity:
return "suspicious"
default:
return "general"
}
}
// SecurityEvent represents a security-related event with comprehensive context.
// Contains timing information, IP address, user agent, request details,
// and custom event-specific data for security analysis and alerting.
type SecurityEvent struct {
// Timestamp when the event occurred
Timestamp time.Time `json:"timestamp"`
// Details contains event-specific additional information
Details map[string]interface{} `json:"details,omitempty"`
// Type categorizes the event (auth_failure, token_failure, etc.)
Type string `json:"type"`
// Severity indicates event importance (low, medium, high)
Severity string `json:"severity"`
// ClientIP is the source IP address of the request
ClientIP string `json:"client_ip"`
// UserAgent is the User-Agent header from the request
UserAgent string `json:"user_agent"`
// RequestPath is the requested URL path
RequestPath string `json:"request_path"`
// Message provides human-readable description of the event
Message string `json:"message"`
}
// SecurityMonitor provides comprehensive security monitoring for the OIDC middleware.
// It tracks failures by IP address, detects suspicious patterns, enforces
// rate limits, and can trigger custom security event handlers.
type SecurityMonitor struct {
ipFailures map[string]*IPFailureTracker
patternDetector *SuspiciousPatternDetector
logger *Logger
cleanupTask *BackgroundTask
eventHandlers []SecurityEventHandler
config SecurityMonitorConfig
ipMutex sync.RWMutex
}
// IPFailureTracker maintains failure statistics and blocking state for an IP address.
// Used for implementing progressive penalties and automatic IP blocking based on
// failure patterns, with support for different failure types for
// rate limiting and IP blocking decisions.
type IPFailureTracker struct {
// LastFailure timestamp of the most recent failure
LastFailure time.Time
// FirstFailure timestamp of the first failure in current window
FirstFailure time.Time
// BlockedUntil indicates when the IP block expires
BlockedUntil time.Time
// FailureTypes tracks counts by failure type
FailureTypes map[string]int64
// FailureCount total number of failures
FailureCount int64
// mutex protects concurrent access to tracker data
mutex sync.RWMutex
// IsBlocked indicates if this IP is currently blocked
IsBlocked bool
}
// SuspiciousPatternDetector identifies attack patterns that may indicate coordinated threats.
// Analyzes events across multiple time windows to detect rapid failures, distributed attacks,
// and persistent attack patterns that individual IP monitoring might miss.
type SuspiciousPatternDetector struct {
// recentEvents stores recent security events for analysis
recentEvents []SecurityEvent
// shortWindow defines time frame for rapid failure detection
shortWindow time.Duration
// mediumWindow defines time frame for distributed attack detection
mediumWindow time.Duration
// longWindow defines time frame for persistent attack detection
longWindow time.Duration
// rapidFailureThreshold triggers rapid failure alerts
rapidFailureThreshold int
// distributedAttackThreshold triggers distributed attack alerts
distributedAttackThreshold int
// persistentAttackThreshold triggers persistent attack alerts
persistentAttackThreshold int
// eventsMutex protects concurrent access to events
eventsMutex sync.RWMutex
}
// SecurityEventHandler defines the interface for processing security events.
// Implementations can log events, send alerts, update external systems,
// or trigger automated response actions.
type SecurityEventHandler interface {
// HandleSecurityEvent processes a security event
HandleSecurityEvent(event SecurityEvent)
}
// SecurityMonitorConfig contains configuration parameters for the security monitor.
// Controls thresholds, time windows, and behavior for security monitoring.
type SecurityMonitorConfig struct {
// MaxFailuresPerIP sets the failure threshold before blocking
MaxFailuresPerIP int `json:"max_failures_per_ip"`
// FailureWindowMinutes defines the time window for counting failures
FailureWindowMinutes int `json:"failure_window_minutes"`
// BlockDurationMinutes sets how long to block an IP
BlockDurationMinutes int `json:"block_duration_minutes"`
// RapidFailureThreshold triggers rapid failure detection
RapidFailureThreshold int `json:"rapid_failure_threshold"`
// CleanupIntervalMinutes sets cleanup frequency for old data
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
RetentionHours int `json:"retention_hours"`
EnablePatternDetection bool `json:"enable_pattern_detection"`
EnableDetailedLogging bool `json:"enable_detailed_logging"`
LogSuspiciousOnly bool `json:"log_suspicious_only"`
}
// DefaultSecurityMonitorConfig returns a default configuration
func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
return SecurityMonitorConfig{
MaxFailuresPerIP: 10,
FailureWindowMinutes: 15,
BlockDurationMinutes: 60,
EnablePatternDetection: true,
RapidFailureThreshold: 5,
EnableDetailedLogging: true,
LogSuspiciousOnly: false,
CleanupIntervalMinutes: 30,
RetentionHours: 24,
}
}
// NewSecurityMonitor creates a new security monitor instance
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
sm := &SecurityMonitor{
ipFailures: make(map[string]*IPFailureTracker),
eventHandlers: make([]SecurityEventHandler, 0),
config: config,
logger: logger,
patternDetector: NewSuspiciousPatternDetector(),
}
sm.startCleanupRoutine()
return sm
}
// NewSuspiciousPatternDetector creates a new pattern detector
func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
return &SuspiciousPatternDetector{
shortWindow: 1 * time.Minute,
mediumWindow: 5 * time.Minute,
longWindow: 15 * time.Minute,
rapidFailureThreshold: 5,
distributedAttackThreshold: 20,
persistentAttackThreshold: 50,
recentEvents: make([]SecurityEvent, 0),
}
}
// RecordSecurityEvent is a generic method to record any type of security event
func (sm *SecurityMonitor) RecordSecurityEvent(
eventType SecurityEventType,
clientIP, userAgent, requestPath string,
message string,
details map[string]interface{},
trackIPFailure bool) {
event := SecurityEvent{
Type: string(eventType),
Severity: eventType.DefaultSeverity(),
Timestamp: time.Now(),
ClientIP: clientIP,
UserAgent: userAgent,
RequestPath: requestPath,
Message: message,
Details: details,
}
if trackIPFailure {
sm.recordIPFailure(clientIP, eventType.IPFailureType())
}
sm.processSecurityEvent(event)
}
// RecordAuthenticationFailure records an authentication failure event
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
if details == nil {
details = make(map[string]interface{})
}
details["reason"] = reason
sm.RecordSecurityEvent(
AuthFailure,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Authentication failed: %s", reason),
details,
true,
)
}
// RecordTokenValidationFailure records a token validation failure
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
details := map[string]interface{}{
"reason": reason,
}
if tokenPrefix != "" {
details["token_prefix"] = tokenPrefix
}
sm.RecordSecurityEvent(
TokenValidFailure,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Token validation failed: %s", reason),
details,
true,
)
}
// RecordRateLimitHit records when rate limiting is triggered
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
details := map[string]interface{}{
"limit_type": "token_verification",
}
sm.RecordSecurityEvent(
RateLimitHit,
clientIP,
userAgent,
requestPath,
"Rate limit exceeded",
details,
true,
)
}
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
if details == nil {
details = make(map[string]interface{})
}
details["activity_type"] = activityType
sm.RecordSecurityEvent(
SuspiciousActivity,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
details,
true,
)
}
// recordIPFailure tracks failures for a specific IP address
func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
sm.ipMutex.Lock()
defer sm.ipMutex.Unlock()
tracker, exists := sm.ipFailures[clientIP]
if !exists {
tracker = &IPFailureTracker{
FailureTypes: make(map[string]int64),
FirstFailure: time.Now(),
}
sm.ipFailures[clientIP] = tracker
}
tracker.mutex.Lock()
defer tracker.mutex.Unlock()
tracker.FailureCount++
tracker.LastFailure = time.Now()
tracker.FailureTypes[failureType]++
windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute)
if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) {
if !tracker.IsBlocked {
tracker.IsBlocked = true
tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute)
sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes)
blockEvent := SecurityEvent{
Type: "ip_blocked",
Severity: "high",
Timestamp: time.Now(),
ClientIP: clientIP,
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
Details: map[string]interface{}{
"failure_count": tracker.FailureCount,
"failure_types": tracker.FailureTypes,
"blocked_until": tracker.BlockedUntil,
},
}
sm.processSecurityEvent(blockEvent)
}
}
}
// IsIPBlocked checks if an IP address is currently blocked
func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool {
sm.ipMutex.RLock()
defer sm.ipMutex.RUnlock()
tracker, exists := sm.ipFailures[clientIP]
if !exists {
return false
}
tracker.mutex.RLock()
defer tracker.mutex.RUnlock()
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
return true
}
if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) {
tracker.IsBlocked = false
sm.logger.Infof("IP %s automatically unblocked", clientIP)
}
return false
}
// processSecurityEvent processes a security event through all handlers and pattern detection
func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
if sm.config.EnablePatternDetection {
sm.patternDetector.AddEvent(event)
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
if len(patterns) == 1 {
sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0])
} else {
sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns)
}
for _, pattern := range patterns {
patternEvent := SecurityEvent{
Type: "suspicious_pattern",
Severity: "high",
Timestamp: time.Now(),
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
Details: map[string]interface{}{
"pattern_type": pattern,
"trigger_event": event,
},
}
sm.handleSecurityEvent(patternEvent)
}
}
}
sm.handleSecurityEvent(event)
}
// handleSecurityEvent sends the event to all registered handlers
func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) {
if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") {
sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)",
event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath)
}
for _, handler := range sm.eventHandlers {
go handler.HandleSecurityEvent(event)
}
}
// AddEventHandler adds a security event handler
func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
sm.eventHandlers = append(sm.eventHandlers, handler)
}
// This is kept for API compatibility but doesn't collect actual metrics
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
return map[string]interface{}{
"tracked_ips": 0,
}
}
// AddEvent adds an event to the pattern detector
func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) {
spd.eventsMutex.Lock()
defer spd.eventsMutex.Unlock()
spd.recentEvents = append(spd.recentEvents, event)
cutoff := time.Now().Add(-spd.longWindow)
var filteredEvents []SecurityEvent
for _, e := range spd.recentEvents {
if e.Timestamp.After(cutoff) {
filteredEvents = append(filteredEvents, e)
}
}
spd.recentEvents = filteredEvents
}
// DetectSuspiciousPatterns analyzes recent events for suspicious patterns
func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
spd.eventsMutex.RLock()
defer spd.eventsMutex.RUnlock()
var patterns []string
now := time.Now()
ipCounts := make(map[string]int)
shortWindowStart := now.Add(-spd.shortWindow)
for _, event := range spd.recentEvents {
if event.Timestamp.After(shortWindowStart) &&
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
ipCounts[event.ClientIP]++
}
}
for ip, count := range ipCounts {
if count >= spd.rapidFailureThreshold {
patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip))
}
}
mediumWindowStart := now.Add(-spd.mediumWindow)
uniqueFailingIPs := make(map[string]bool)
for _, event := range spd.recentEvents {
if event.Timestamp.After(mediumWindowStart) &&
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
uniqueFailingIPs[event.ClientIP] = true
}
}
if len(uniqueFailingIPs) >= spd.distributedAttackThreshold {
patterns = append(patterns, "distributed_attack_pattern")
}
longWindowStart := now.Add(-spd.longWindow)
persistentFailures := 0
for _, event := range spd.recentEvents {
if event.Timestamp.After(longWindowStart) &&
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
persistentFailures++
}
}
if persistentFailures >= spd.persistentAttackThreshold {
patterns = append(patterns, "persistent_attack_pattern")
}
return patterns
}
// startCleanupRoutine starts the background cleanup routine
func (sm *SecurityMonitor) startCleanupRoutine() {
sm.cleanupTask = NewBackgroundTask(
"security-monitor-cleanup",
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
sm.cleanup,
sm.logger)
sm.cleanupTask.Start()
}
// StopCleanupRoutine stops the background cleanup routine
func (sm *SecurityMonitor) StopCleanupRoutine() {
if sm.cleanupTask != nil {
sm.cleanupTask.Stop()
sm.cleanupTask = nil
}
}
// cleanup removes old tracking data
func (sm *SecurityMonitor) cleanup() {
sm.ipMutex.Lock()
defer sm.ipMutex.Unlock()
cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour)
for ip, tracker := range sm.ipFailures {
tracker.mutex.RLock()
shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked
tracker.mutex.RUnlock()
if shouldRemove {
delete(sm.ipFailures, ip)
}
}
sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures))
}
// ExtractClientIP extracts the client IP from the request, considering proxy headers
func ExtractClientIP(r *http.Request) string {
if xri := r.Header.Get("X-Real-IP"); xri != "" {
if net.ParseIP(xri) != nil {
return xri
}
}
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
if len(ips) > 0 {
ip := strings.TrimSpace(ips[0])
if net.ParseIP(ip) != nil {
return ip
}
}
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
// LoggingSecurityEventHandler logs security events to the standard logger
type LoggingSecurityEventHandler struct {
logger *Logger
}
// NewLoggingSecurityEventHandler creates a new logging event handler
func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler {
return &LoggingSecurityEventHandler{logger: logger}
}
// HandleSecurityEvent implements SecurityEventHandler
func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
switch event.Severity {
case "high":
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
case "medium":
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
case "low":
h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
default:
h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
}
}
-285
View File
@@ -1,285 +0,0 @@
package traefikoidc
import (
"net/http/httptest"
"strconv"
"testing"
"time"
)
func TestSecurityMonitor(t *testing.T) {
config := DefaultSecurityMonitorConfig()
config.MaxFailuresPerIP = 3
config.BlockDurationMinutes = 1 // 1 minute for testing
config.CleanupIntervalMinutes = 1
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
defer func() {
// Allow cleanup goroutine to finish
time.Sleep(150 * time.Millisecond)
}()
t.Run("Record authentication failure", func(t *testing.T) {
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "invalid credentials", nil)
// Should not be blocked after first failure
if monitor.IsIPBlocked("192.168.1.1") {
t.Error("IP should not be blocked after first failure")
}
})
t.Run("IP blocked after max failures", func(t *testing.T) {
// Record multiple failures
for i := 0; i < config.MaxFailuresPerIP; i++ {
monitor.RecordAuthenticationFailure("192.168.1.2", "test-agent", "/login", "invalid credentials", nil)
}
// Should be blocked now
if !monitor.IsIPBlocked("192.168.1.2") {
t.Error("IP should be blocked after max failures")
}
})
t.Run("Token validation failure", func(t *testing.T) {
// Just verify the method doesn't panic
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
})
t.Run("Rate limit hit", func(t *testing.T) {
// Just verify the method doesn't panic
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
})
t.Run("Suspicious activity", func(t *testing.T) {
details := map[string]interface{}{"pattern": "unusual"}
// Just verify the method doesn't panic
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
})
}
func TestSuspiciousPatternDetector(t *testing.T) {
detector := NewSuspiciousPatternDetector()
t.Run("Add events and detect patterns", func(t *testing.T) {
// Add multiple events from same IP
for i := 0; i < 10; i++ {
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1.100",
Timestamp: time.Now(),
}
detector.AddEvent(event)
}
patterns := detector.DetectSuspiciousPatterns()
found := false
for _, p := range patterns {
if p == "rapid_failures_from_ip_192.168.1.100" {
found = true
break
}
}
if !found {
t.Error("Expected to detect rapid failure pattern")
}
})
t.Run("Detect distributed attack pattern", func(t *testing.T) {
// Add failures from many different IPs
for i := 0; i < 25; i++ {
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1." + strconv.Itoa(100+i),
Timestamp: time.Now(),
}
detector.AddEvent(event)
}
patterns := detector.DetectSuspiciousPatterns()
found := false
for _, p := range patterns {
if p == "distributed_attack_pattern" {
found = true
break
}
}
if !found {
t.Error("Expected to detect distributed attack pattern")
}
})
}
func TestExtractClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
headers map[string]string
expectedIP string
}{
{
name: "Direct connection",
remoteAddr: "192.168.1.1:12345",
expectedIP: "192.168.1.1",
},
{
name: "X-Forwarded-For header",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 10.0.0.1"},
expectedIP: "203.0.113.1",
},
{
name: "X-Real-IP header",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{"X-Real-IP": "203.0.113.2"},
expectedIP: "203.0.113.2",
},
{
name: "Multiple headers - X-Real-IP takes precedence",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Forwarded-For": "203.0.113.1",
"X-Real-IP": "203.0.113.2",
},
expectedIP: "203.0.113.2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
for key, value := range tt.headers {
req.Header.Set(key, value)
}
ip := ExtractClientIP(req)
if ip != tt.expectedIP {
t.Errorf("Expected IP %s, got %s", tt.expectedIP, ip)
}
})
}
}
func TestSecurityEventHandlers(t *testing.T) {
t.Run("Logging security event handler", func(t *testing.T) {
logger := NewLogger("debug")
handler := NewLoggingSecurityEventHandler(logger)
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1.1",
Timestamp: time.Now(),
Message: "Test failure",
Severity: "medium",
}
// Should not panic
handler.HandleSecurityEvent(event)
})
// Metrics security event handler test removed as part of metrics cleanup
}
func TestSecurityMonitorEventHandlers(t *testing.T) {
config := DefaultSecurityMonitorConfig()
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Add event handler with proper synchronization
handlerCalled := make(chan bool, 1)
handler := &testSecurityEventHandler{
callback: func(event SecurityEvent) {
select {
case handlerCalled <- true:
default:
// Channel already has a value, don't block
}
},
}
monitor.AddEventHandler(handler)
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "test failure", nil)
// Wait for event handler to be called with timeout
select {
case <-handlerCalled:
// Success - handler was called
case <-time.After(100 * time.Millisecond):
t.Error("Expected event handler to be called within timeout")
}
}
// Test helper for security event handler
type testSecurityEventHandler struct {
callback func(SecurityEvent)
}
func (h *testSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
h.callback(event)
}
func TestDefaultSecurityMonitorConfig(t *testing.T) {
config := DefaultSecurityMonitorConfig()
if config.MaxFailuresPerIP <= 0 {
t.Error("Expected positive MaxFailuresPerIP")
}
if config.BlockDurationMinutes <= 0 {
t.Error("Expected positive BlockDurationMinutes")
}
if config.CleanupIntervalMinutes <= 0 {
t.Error("Expected positive CleanupIntervalMinutes")
}
if config.FailureWindowMinutes <= 0 {
t.Error("Expected positive FailureWindowMinutes")
}
}
func TestSecurityMonitorCleanup(t *testing.T) {
config := DefaultSecurityMonitorConfig()
config.CleanupIntervalMinutes = 1
config.BlockDurationMinutes = 1
config.RetentionHours = 1
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Block an IP
for i := 0; i < config.MaxFailuresPerIP; i++ {
monitor.RecordAuthenticationFailure("192.168.1.99", "test-agent", "/login", "test", nil)
}
// Verify it's blocked
if !monitor.IsIPBlocked("192.168.1.99") {
t.Error("IP should be blocked")
}
// Wait a bit and check if it gets unblocked automatically
time.Sleep(100 * time.Millisecond)
// The IP should still be blocked since we haven't waited long enough
if !monitor.IsIPBlocked("192.168.1.99") {
t.Error("IP should still be blocked")
}
}
func TestSecurityEventTypes(t *testing.T) {
config := DefaultSecurityMonitorConfig()
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Test different event types - just verify they don't panic
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
details := map[string]interface{}{"pattern": "test"}
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
// Just verify GetSecurityMetrics doesn't panic
_ = monitor.GetSecurityMetrics()
}
+62 -8
View File
@@ -4,7 +4,9 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
@@ -31,6 +33,45 @@ func constantTimeStringCompare(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
// deriveCookieKeys derives an independent 64-byte HMAC authentication key and a
// 32-byte AES-256 encryption key from the operator-provided session encryption
// key using HKDF-SHA256 (RFC 5869).
//
// gorilla/securecookie only ENCRYPTS the cookie payload when a block
// (encryption) key is supplied; constructing the store with a single key leaves
// sessions signed-but-plaintext, so the OIDC access/refresh/ID tokens stored in
// the cookie are recoverable by anyone who can read the raw cookie bytes. Two
// independent keys are derived here so the cookie is both encrypted and
// authenticated. HKDF is implemented with stdlib hmac+sha256 so it runs under
// Traefik's yaegi interpreter, which may not export crypto/hkdf.
func deriveCookieKeys(secret string) (authKey, encKey []byte) {
okm := hkdfSHA256([]byte(secret), nil, []byte("traefikoidc session cookie keys v1"), 96)
return okm[:64], okm[64:96]
}
// hkdfSHA256 performs HKDF-Extract followed by HKDF-Expand (RFC 5869) using
// HMAC-SHA256 and returns length bytes of output keying material.
func hkdfSHA256(ikm, salt, info []byte, length int) []byte {
if len(salt) == 0 {
salt = make([]byte, sha256.Size)
}
// Extract: PRK = HMAC-SHA256(salt, IKM)
ext := hmac.New(sha256.New, salt)
ext.Write(ikm)
prk := ext.Sum(nil)
// Expand: T(i) = HMAC-SHA256(PRK, T(i-1) | info | i)
var out, t []byte
for i := byte(1); len(out) < length; i++ {
exp := hmac.New(sha256.New, prk)
exp.Write(t)
exp.Write(info)
exp.Write([]byte{i})
t = exp.Sum(nil)
out = append(out, t...)
}
return out[:length]
}
// min returns the minimum of two integers.
// This is a utility function used throughout the session management code.
// Parameters:
@@ -118,12 +159,12 @@ var knownSessionKeys = map[string]bool{
"id_token": true,
"user_identifier": true,
"authenticated": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
"incoming_path": true,
"created_at": true,
"redirect_count": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
"incoming_path": true,
"created_at": true,
"redirect_count": true,
}
// compressCombinedPayload compresses the combined session payload using gzip.
@@ -423,8 +464,13 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
ctx, cancel := context.WithCancel(context.Background())
// Derive independent authentication + encryption keys so the session cookie
// is AES-256 encrypted and HMAC authenticated, not merely signed. See
// deriveCookieKeys: a single key would leave the stored tokens in plaintext.
authKey, encKey := deriveCookieKeys(encryptionKey)
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
store: sessions.NewCookieStore(authKey, encKey),
forceHTTPS: forceHTTPS,
cookieDomain: cookieDomain,
cookiePrefix: cookiePrefix,
@@ -435,6 +481,14 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
cancel: cancel,
}
// Bind the cookie codec's timestamp validity (and the cookie Max-Age) to the
// configured session lifetime instead of gorilla's 30-day default, so a
// stolen cookie is not cryptographically valid for up to 30 days regardless
// of the (possibly much shorter) configured sessionMaxAge (rank 13).
if cs, ok := sm.store.(*sessions.CookieStore); ok {
cs.MaxAge(int(sessionMaxAge.Seconds()))
}
// Initialize global memory monitoring (singleton)
sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
@@ -1566,7 +1620,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
sd.sessionMutex.Lock()
sd.clearAllSessionData(r, true)
// Release the lock before calling Save to prevent deadlock
sd.sessionMutex.Unlock()
+28 -1
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
@@ -54,6 +55,7 @@ type Config struct {
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedUsers []string `json:"allowedUsers"`
Headers []TemplatedHeader `json:"headers"`
ExtraAuthParams map[string]string `json:"extraAuthParams,omitempty"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
// a stored refresh token. Once the token has been in the session longer
@@ -761,7 +763,32 @@ func validateTemplateSecure(templateStr string) error {
// Returns true if the URL is valid and secure (HTTPS), false otherwise.
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme == "https" && u.Host != ""
if err != nil || u.Host == "" {
return false
}
if u.Scheme == "https" {
return true
}
// Permit plaintext HTTP only for loopback hosts (local development,
// in-cluster sidecar providers, tests). Loopback traffic never leaves the
// host, so it is not exposed to network MITM; remote providers must use
// HTTPS. Mirrors the RFC 8252 loopback allowance.
if u.Scheme == "http" && isLoopbackHost(u.Hostname()) {
return true
}
return false
}
// isLoopbackHost reports whether host is "localhost" or a loopback IP literal
// (127.0.0.0/8 or ::1).
func isLoopbackHost(host string) bool {
if strings.EqualFold(host, "localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
return ip.IsLoopback()
}
return false
}
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
+41 -11
View File
@@ -106,8 +106,9 @@ func (rm *ResourceManager) GetCache(key string) interface{} {
case "jwk-cache":
cache = cacheManager.GetSharedJWKCache()
default:
// Generic cache implementation
cache = NewGenericCache(1*time.Hour, rm.logger)
// Generic cache implementation; bind cleanup goroutine to the manager's
// shutdown channel so it exits when the ResourceManager shuts down.
cache = newGenericCacheWithOwner(1*time.Hour, rm.logger, rm.shutdownChan)
}
rm.caches[key] = cache
@@ -263,6 +264,19 @@ func (rm *ResourceManager) cleanupInstance(instanceID string) {
// This is a hook for future instance-specific cleanup needs
}
// liveInstanceCount tracks the number of fully-constructed TraefikOidc plugin
// instances alive in this process. Process-global singleton tasks (such as the
// shared token-cleanup) must only be stopped when the LAST instance shuts down,
// otherwise one instance's teardown would disable them for all survivors.
var liveInstanceCount int32
// registerLiveInstance records a newly constructed plugin instance.
func registerLiveInstance() { atomic.AddInt32(&liveInstanceCount, 1) }
// unregisterLiveInstance records a plugin instance shutting down and returns the
// number of instances still alive afterwards.
func unregisterLiveInstance() int32 { return atomic.AddInt32(&liveInstanceCount, -1) }
// Shutdown gracefully shuts down all managed resources
func (rm *ResourceManager) Shutdown(ctx context.Context) error {
var err error
@@ -501,20 +515,31 @@ func (p *GoroutinePool) Shutdown(ctx context.Context) error {
// GenericCache provides a simple cache implementation for testing
type GenericCache struct {
data map[string]interface{}
logger *Logger
stopChan chan struct{}
ttl time.Duration
mu sync.RWMutex
data map[string]interface{}
// ownerStopChan, when non-nil, signals the cleanup goroutine to exit when
// the owning ResourceManager shuts down, so the goroutine cannot outlive it.
ownerStopChan <-chan struct{}
logger *Logger
stopChan chan struct{}
ttl time.Duration
mu sync.RWMutex
}
// NewGenericCache creates a new generic cache
func NewGenericCache(ttl time.Duration, logger *Logger) *GenericCache {
return newGenericCacheWithOwner(ttl, logger, nil)
}
// newGenericCacheWithOwner creates a generic cache whose cleanup goroutine also
// exits when ownerStopChan is closed (typically the ResourceManager shutdown
// channel), guaranteeing the goroutine is stoppable on shutdown.
func newGenericCacheWithOwner(ttl time.Duration, logger *Logger, ownerStopChan <-chan struct{}) *GenericCache {
cache := &GenericCache{
data: make(map[string]interface{}),
ttl: ttl,
logger: logger,
stopChan: make(chan struct{}),
data: make(map[string]interface{}),
ttl: ttl,
logger: logger,
stopChan: make(chan struct{}),
ownerStopChan: ownerStopChan,
}
// Start cleanup routine
@@ -570,6 +595,11 @@ func (gc *GenericCache) cleanupRoutine() {
gc.mu.Unlock()
case <-gc.stopChan:
return
case <-gc.ownerStopChan:
// Owning ResourceManager is shutting down; exit so the goroutine
// does not outlive its owner. A nil channel blocks forever, so this
// case is inert when no owner is set.
return
}
}
}
+27 -15
View File
@@ -296,9 +296,12 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
// Create a TraefikOidc instance with context
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
}
plugin, err := NewWithContext(ctx, config, nil, "test")
@@ -350,9 +353,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
initialGoroutines := runtime.NumGoroutine()
configs := []Config{
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
}
var plugins []*TraefikOidc
@@ -432,9 +435,12 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
for i := 0; i < 3; i++ {
ctx := context.Background()
config := &Config{
ProviderURL: mockServers[i].URL,
ClientID: fmt.Sprintf("client%d", i),
ClientSecret: fmt.Sprintf("secret%d", i),
ProviderURL: mockServers[i].URL,
ClientID: fmt.Sprintf("client%d", i),
ClientSecret: fmt.Sprintf("secret%d", i),
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
}
plugin, err := NewWithContext(ctx, config, nil, fmt.Sprintf("test-%d", i))
@@ -595,9 +601,12 @@ func TestBackwardCompatibility(t *testing.T) {
t.Run("LegacyNewFunction", func(t *testing.T) {
// Test that the original New function still works
config := &Config{
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
}
handler, err := New(context.Background(), nil, config, "test")
@@ -617,9 +626,12 @@ func TestBackwardCompatibility(t *testing.T) {
t.Run("ExistingAPICompatibility", func(t *testing.T) {
config := &Config{
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
}
handler, _ := New(context.Background(), nil, config, "test")
+29 -14
View File
@@ -21,13 +21,16 @@ type IntrospectionResponse struct {
Username string `json:"username,omitempty"`
TokenType string `json:"token_type,omitempty"`
Sub string `json:"sub,omitempty"`
Aud string `json:"aud,omitempty"`
Iss string `json:"iss,omitempty"`
Jti string `json:"jti,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Active bool `json:"active"`
// Aud holds the introspection audience. Per RFC 7662 it may be a single
// string or an array of strings, so it is decoded as interface{} and
// matched with verifyAudience (which handles both shapes).
Aud interface{} `json:"aud,omitempty"`
Iss string `json:"iss,omitempty"`
Jti string `json:"jti,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Active bool `json:"active"`
}
// introspectToken performs OAuth 2.0 Token Introspection (RFC 7662) for an opaque token.
@@ -120,7 +123,7 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
// Parse response per RFC 7662 Section 2.2
var introspectionResp IntrospectionResponse
if err := json.NewDecoder(resp.Body).Decode(&introspectionResp); err != nil {
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&introspectionResp); err != nil {
return nil, fmt.Errorf("failed to decode introspection response: %w", err)
}
@@ -128,6 +131,12 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
if t.introspectionCache != nil {
// Cache for a short duration or until token expiry (whichever is shorter)
cacheDuration := 5 * time.Minute
// When introspection is REQUIRED, operators expect near-real-time
// revocation; cap the positive-result cache so a token revoked at the
// provider cannot keep passing for the full 5 minutes (rank 8).
if t.requireTokenIntrospection && cacheDuration > 30*time.Second {
cacheDuration = 30 * time.Second
}
if introspectionResp.Exp > 0 {
expTime := time.Unix(introspectionResp.Exp, 0)
untilExp := time.Until(expTime)
@@ -197,12 +206,18 @@ func (t *TraefikOidc) validateOpaqueToken(token string) error {
}
}
// Validate audience if configured
// Note: For opaque tokens, audience validation via introspection may be limited
// depending on what the introspection endpoint returns
if t.audience != "" && t.audience != t.clientID && resp.Aud != "" {
if resp.Aud != t.audience {
return fmt.Errorf("invalid audience: expected %s, got %s", t.audience, resp.Aud)
// Validate audience if configured. When a distinct API audience is
// configured (audience != clientID), the introspection response MUST carry
// a matching audience. Fail closed on a missing or mismatched aud: a token
// whose audience cannot be confirmed must not be accepted, otherwise a
// token minted for a different audience would pass. aud may be a single
// string or an array of strings (RFC 7662); verifyAudience handles both.
if t.audience != "" && t.audience != t.clientID {
if resp.Aud == nil {
return fmt.Errorf("invalid audience: expected %s, introspection response has no audience", t.audience)
}
if err := verifyAudience(resp.Aud, t.audience); err != nil {
return fmt.Errorf("invalid audience: expected %s: %w", t.audience, err)
}
}
+9 -439
View File
@@ -5,8 +5,8 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
@@ -214,11 +214,13 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
//
//nolint:gocognit,gocyclo // Complex token type detection with multiple provider-specific checks
func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool {
// Use first 32 chars of token as cache key (sufficient for uniqueness)
cacheKey := token
if len(token) > 32 {
cacheKey = token[:32]
}
// Key on a hash of the FULL token. The first 32 characters of a JWT are
// only the base64url-encoded header, which is identical for every token
// sharing the same alg+kid, so distinct tokens (e.g. an ID token and an
// access token from the same issuer) would otherwise collide on the cache
// key and be mis-classified.
sum := sha256.Sum256([]byte(token))
cacheKey := hex.EncodeToString(sum[:])
// Check cache first
if t.tokenTypeCache != nil {
@@ -860,438 +862,6 @@ func (t *TraefikOidc) isAzureProvider() bool {
strings.Contains(issuerURL, "login.windows.net")
}
// validateAzureTokens validates tokens with Azure AD-specific logic.
// Azure tokens may be opaque access tokens that cannot be verified as JWTs,
// so this method handles both JWT and opaque token scenarios.
// Parameters:
// - session: The session data containing tokens to validate.
//
// Returns:
// - authenticated: Whether the user has valid authentication.
// - needsRefresh: Whether tokens need to be refreshed.
// - expired: Whether tokens have expired and cannot be refreshed.
//
//nolint:gocognit // Azure-specific validation requires multiple token type checks
func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("Azure user is not authenticated according to session flag")
if session.GetRefreshToken() != "" {
t.logger.Debug("Azure session not authenticated, but refresh token exists. Signaling need for refresh.")
return false, true, false
}
return false, true, false
}
accessToken := session.GetAccessToken()
idToken := session.GetIDToken()
if accessToken != "" {
if strings.Count(accessToken, ".") == 2 {
// Microsoft documents that client apps cannot validate access
// tokens issued for Microsoft-owned APIs (Graph, Azure Mgmt) due
// to their proprietary signing format (nonce in JWT header is
// the marker — signed bytes hash the nonce, wire bytes ship the
// raw value, so rsa verification always fails). Treat such
// tokens as opaque, matching Microsoft's guidance and avoiding
// per-request signature-error log spam (issue #134 followup).
//
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
// "you can't validate tokens for Microsoft Graph according to
// these rules due to their proprietary format"
if t.isUnverifiableAzureAccessToken(accessToken) {
t.logger.Debug("Azure access token is Microsoft-proprietary (Graph/Mgmt) — treating as opaque per Microsoft guidance")
if idToken != "" {
if err := t.verifyToken(idToken); err != nil {
t.logger.Debugf("Azure: ID token validation failed while access token was opaque: %v", err)
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiry(session, idToken)
}
return true, false, false
}
if err := t.verifyToken(accessToken); err != nil {
if idToken != "" {
if err := t.verifyToken(idToken); err != nil {
t.logger.Debugf("Azure: Both access and ID token validation failed: %v", err)
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiry(session, idToken)
}
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiry(session, accessToken)
}
t.logger.Debug("Azure access token appears opaque, treating as valid")
if idToken != "" {
return t.validateTokenExpiry(session, idToken)
}
return true, false, false
}
if idToken != "" {
if err := t.verifyToken(idToken); err != nil {
if strings.Contains(err.Error(), "token has expired") {
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiry(session, idToken)
}
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
// validateGoogleTokens handles Google-specific token validation logic.
// Currently delegates to standard token validation but provides a hook
// for Google-specific validation requirements in the future.
// Parameters:
// - session: The session data containing tokens to validate.
//
// Returns:
// - authenticated: Whether the user has valid authentication.
// - needsRefresh: Whether tokens need to be refreshed.
// - expired: Whether tokens have expired and cannot be refreshed.
func (t *TraefikOidc) validateGoogleTokens(session *SessionData) (bool, bool, bool) {
return t.validateStandardTokens(session)
}
// validateStandardTokens handles standard OIDC token validation logic.
// This is the default validation method for generic OIDC providers.
// It verifies ID tokens and handles access tokens appropriately.
// Parameters:
// - session: The session data containing tokens to validate.
//
// Returns:
// - authenticated: Whether the user has valid authentication.
// - needsRefresh: Whether tokens need to be refreshed.
// - expired: Whether tokens have expired and cannot be refreshed.
//
//nolint:gocognit,gocyclo // Complex validation logic handles multiple token scenarios and edge cases
func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool, bool) {
authenticated := session.GetAuthenticated()
// Removed debug output
if !authenticated {
t.logger.Debug("User is not authenticated according to session flag")
if session.GetRefreshToken() != "" {
t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.")
return false, true, false
}
return false, false, false
}
accessToken := session.GetAccessToken()
// Removed debug output
if accessToken == "" {
t.logger.Debug("Authenticated flag set, but no access token found in session")
if session.GetRefreshToken() != "" {
// Check if we have an ID token to determine if we're beyond grace period
// When access token is missing, check ID token expiry to determine if refresh is viable
idToken := session.GetIDToken()
t.logger.Debugf("Checking ID token for grace period: ID token present: %v", idToken != "")
if idToken != "" {
// Try to parse the ID token to check its expiry
parts := strings.Split(idToken, ".")
if len(parts) == 3 {
// Decode the claims part
claimsData, err := base64.RawURLEncoding.DecodeString(parts[1])
if err == nil {
var claims map[string]interface{}
if err := json.Unmarshal(claimsData, &claims); err == nil {
if expClaim, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(expClaim), 0)
if time.Now().After(expTime) {
expiredDuration := time.Since(expTime)
if expiredDuration > t.refreshGracePeriod {
t.logger.Debugf("ID token expired beyond grace period (%v > %v), must re-authenticate",
expiredDuration, t.refreshGracePeriod)
return false, false, true // expired, cannot refresh
}
t.logger.Debugf("ID token expired %v ago, within grace period %v, allowing refresh",
expiredDuration, t.refreshGracePeriod)
}
}
}
}
}
}
t.logger.Debug("Access token missing, but refresh token exists. Signaling need for refresh.")
return false, true, false
}
return false, false, true
}
// Check if access token is opaque (doesn't have JWT structure)
dotCount := strings.Count(accessToken, ".")
isOpaqueToken := dotCount != 2
// For opaque access tokens, use introspection if available (RFC 7662 - Option C: Scenario 3)
if isOpaqueToken {
t.logger.Debugf("Access token appears to be opaque (dots: %d)", dotCount)
// Try introspection first if opaque tokens are allowed
if t.allowOpaqueTokens {
if err := t.validateOpaqueToken(accessToken); err != nil {
errMsg := err.Error()
t.logger.Infof("⚠️ Opaque access token validation via introspection failed: %v", err)
// Check if the token was explicitly marked as inactive/revoked/expired by the provider
// In these cases, we should NOT fall back to ID token - the provider has explicitly
// told us this token is no longer valid. We must refresh or re-authenticate.
isTokenInvalid := strings.Contains(errMsg, "token is not active") ||
strings.Contains(errMsg, "revoked") ||
strings.Contains(errMsg, "token has expired")
if isTokenInvalid {
t.logger.Infof("⚠️ Token explicitly marked as invalid by provider, cannot fall back to ID token")
if session.GetRefreshToken() != "" {
t.logger.Debug("Refresh token available, attempting refresh")
return false, true, false
}
t.logger.Debug("No refresh token available, must re-authenticate")
return false, false, true
}
// If introspection required, reject the session
if t.requireTokenIntrospection {
t.logger.Errorf("❌ SECURITY: Opaque token rejected (introspection required but failed)")
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
// Only fall back to ID token validation for transient errors (network issues, etc.)
// where the introspection endpoint couldn't be reached
t.logger.Infof("⚠️ Falling back to ID token validation for opaque access token (transient error)")
} else {
// Introspection successful
t.logger.Debugf("✓ Opaque access token validated via introspection")
// Still need to check ID token for session expiry
idToken := session.GetIDToken()
if idToken != "" {
return t.validateTokenExpiry(session, idToken)
}
return true, false, false
}
} else {
// Opaque tokens not allowed - log warning and reject or fall back
t.logger.Infof("⚠️ Opaque access token detected but allowOpaqueTokens=false")
}
// Fall back to ID token validation
idToken := session.GetIDToken()
if idToken == "" {
t.logger.Debug("Opaque access token present but no ID token found")
if session.GetRefreshToken() != "" {
t.logger.Debug("ID token missing but refresh token exists. Signaling need for refresh.")
return false, true, false
}
// Accept session with opaque access token even without ID token
// The OAuth provider validated it when issued
t.logger.Debug("Accepting session with opaque access token")
return true, false, false
}
// Validate ID token if present
if err := t.verifyToken(idToken); err != nil {
if strings.Contains(err.Error(), "token has expired") {
t.logger.Debugf("ID token expired with opaque access token, needs refresh")
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
t.logger.Errorf("ID token verification failed with opaque access token: %v", err)
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
// Use ID token for expiry validation
return t.validateTokenExpiry(session, idToken)
}
// JWT access token present - validate it explicitly to detect Scenario 2
// (Option C: Scenario 2 detection and strict mode)
accessTokenValid := false
accessTokenError := ""
if err := t.verifyToken(accessToken); err != nil {
// Access token validation failed
accessTokenError = err.Error()
// Check if it's an audience validation failure (Scenario 2)
if strings.Contains(accessTokenError, "invalid audience") || strings.Contains(accessTokenError, "audience") {
// SCENARIO 2 DETECTED: Access token has wrong audience
t.logger.Infof("⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch: %v", err)
if t.strictAudienceValidation {
// Strict mode: Reject the session (don't fall back to ID token)
t.logger.Errorf("❌ SECURITY: Session rejected due to access token audience mismatch (strictAudienceValidation=true)")
t.logger.Errorf("❌ This prevents potential cross-API token confusion attacks (Auth0 Scenario 2)")
if session.GetRefreshToken() != "" {
return false, true, false // try refresh
}
return false, false, true // must re-authenticate
}
// Backward compatibility mode: Log loud warning but allow fallback to ID token
t.logger.Infof("⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!")
t.logger.Infof("⚠️ This could allow tokens intended for different APIs to grant access")
t.logger.Infof("⚠️ Set strictAudienceValidation=true to enforce proper audience validation")
t.logger.Infof("⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74")
} else if !strings.Contains(accessTokenError, "token has expired") {
// Other validation errors (not expiration, not audience)
t.logger.Debugf("Access token validation failed (non-expiration, non-audience): %v", err)
}
} else {
// Access token is valid
accessTokenValid = true
}
idToken := session.GetIDToken()
if idToken == "" {
if accessTokenValid {
// Access token is valid, no ID token needed
t.logger.Debug("Access token valid, no ID token present")
return t.validateTokenExpiry(session, accessToken)
}
t.logger.Debug("Authenticated flag set with access token, but no ID token found in session")
if session.GetRefreshToken() != "" {
t.logger.Debug("ID token missing but refresh token exists. Signaling conditional refresh to obtain ID token.")
return true, true, false
}
return true, false, false
}
// Validate ID token
if err := t.verifyToken(idToken); err != nil {
if strings.Contains(err.Error(), "token has expired") {
t.logger.Debugf("ID token signature/claims valid but token expired, needs refresh")
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
t.logger.Errorf("ID token verification failed (non-expiration): %v", err)
if session.GetRefreshToken() != "" {
t.logger.Debug("ID token verification failed, but refresh token exists. Signaling need for refresh.")
return false, true, false
}
return false, false, true
}
// If access token was valid, use it for expiry; otherwise use ID token
if accessTokenValid {
return t.validateTokenExpiry(session, accessToken)
}
return t.validateTokenExpiry(session, idToken)
}
// validateTokenExpiry checks if a token is nearing expiration and needs refresh.
// It uses the configured grace period to determine when proactive refresh should occur.
// Parameters:
// - session: The session data for refresh token availability.
// - token: The token to check expiry for.
//
// Returns:
// - authenticated: Whether the token is currently valid.
// - needsRefresh: Whether the token is nearing expiration and should be refreshed.
// - expired: Whether the token is invalid or verification failed.
func (t *TraefikOidc) validateTokenExpiry(session *SessionData, token string) (bool, bool, bool) {
cachedClaims, found := t.tokenCache.Get(token)
if !found {
t.logger.Debug("Claims not found in cache after successful token verification")
if session.GetRefreshToken() != "" {
t.logger.Debug("Claims missing post-verification, attempting refresh to recover.")
return false, true, false
}
return false, false, true
}
expClaim, ok := cachedClaims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
if session.GetRefreshToken() != "" {
t.logger.Debug("Token missing 'exp' claim, but refresh token exists. Signaling need for refresh.")
return false, true, false
}
return false, false, true
}
expTime := int64(expClaim)
expTimeObj := time.Unix(expTime, 0)
nowObj := time.Now()
// Check if token has already expired
if expTimeObj.Before(nowObj) {
// Token has expired
expiredDuration := nowObj.Sub(expTimeObj)
t.logger.Debugf("Token expired %v ago, grace period is %v",
expiredDuration, t.refreshGracePeriod)
// If we have a refresh token, always attempt to use it regardless of grace period
// The refresh token has its own expiry and the provider will reject it if invalid
if session.GetRefreshToken() != "" {
t.logger.Debugf("Token expired, attempting refresh with available refresh token")
return false, true, false // needs refresh
}
// No refresh token available - must re-authenticate
t.logger.Debugf("Token expired and no refresh token available, must re-authenticate")
return false, false, true // expired, cannot refresh
}
// Token not yet expired - check if nearing expiration
refreshThreshold := nowObj.Add(t.refreshGracePeriod)
t.logger.Debugf("Token expires at %v, now is %v, refresh threshold is %v",
expTimeObj.Format(time.RFC3339),
nowObj.Format(time.RFC3339),
refreshThreshold.Format(time.RFC3339))
if expTimeObj.Before(refreshThreshold) {
remainingSeconds := int64(time.Until(expTimeObj).Seconds())
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh",
remainingSeconds, t.refreshGracePeriod)
if session.GetRefreshToken() != "" {
return true, true, false
}
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
return true, false, false
}
t.logger.Debugf("Token is valid and not nearing expiration (expires in %d seconds, outside %s grace period)",
int64(time.Until(expTimeObj).Seconds()), t.refreshGracePeriod)
return true, false, false
}
// startTokenCleanup starts background cleanup goroutines for cache maintenance.
// It runs periodic cleanup of token cache, JWK cache, and session chunks.
// Includes panic recovery to ensure stability.
+5 -535
View File
@@ -3,7 +3,9 @@ package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
@@ -885,10 +887,9 @@ func TestDetectTokenTypeCaching(t *testing.T) {
},
}
token := "test-token-for-caching-with-enough-characters-for-key"
cacheKey := token
if len(token) > 32 {
cacheKey = token[:32]
}
// The cache key is a SHA-256 hash of the full token (collision-resistant).
sum := sha256.Sum256([]byte(token))
cacheKey := hex.EncodeToString(sum[:])
result := tr.detectTokenType(jwt, token)
if !result {
@@ -911,521 +912,6 @@ func TestDetectTokenTypeCaching(t *testing.T) {
}
}
// =============================================================================
// TOKEN VALIDATOR TESTS
// =============================================================================
func TestNewTokenValidator(t *testing.T) {
validator := NewTokenValidator(nil)
if validator == nil {
t.Fatal("Expected non-nil token validator")
}
if validator.logger == nil {
t.Error("Expected logger to be initialized")
}
}
func TestNewTokenValidatorWithLogger(t *testing.T) {
logger := GetSingletonNoOpLogger()
validator := NewTokenValidator(logger)
if validator == nil {
t.Fatal("Expected non-nil token validator")
}
if validator.logger != logger {
t.Error("Expected provided logger to be used")
}
}
func TestValidateTokenEmpty(t *testing.T) {
validator := NewTokenValidator(nil)
result := validator.ValidateToken("", false)
if result.Valid {
t.Error("Expected invalid result for empty token")
}
if result.Error == nil {
t.Error("Expected error for empty token")
}
if !strings.Contains(result.Error.Error(), "empty") {
t.Errorf("Expected 'empty' in error, got: %v", result.Error)
}
}
func TestValidateTokenRequireJWT(t *testing.T) {
validator := NewTokenValidator(nil)
result := validator.ValidateToken("opaque_token_value_here", true)
if result.Valid {
t.Error("Expected invalid result for opaque token when JWT required")
}
if result.Error == nil {
t.Error("Expected error when JWT required but opaque token provided")
}
}
func TestValidateJWTValidFormat(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if !result.Valid {
t.Errorf("Expected valid result, got error: %v", result.Error)
}
if result.TokenType != "JWT" {
t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType)
}
if result.Claims == nil {
t.Error("Expected claims to be parsed")
}
if result.Expiry == nil {
t.Error("Expected expiry to be extracted")
}
if result.IssuedAt == nil {
t.Error("Expected issued at to be extracted")
}
}
func TestValidateJWTExpiredToken(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for expired token")
}
if result.Error == nil {
t.Error("Expected error for expired token")
}
if !strings.Contains(result.Error.Error(), "expired") {
t.Errorf("Expected 'expired' in error, got: %v", result.Error)
}
}
func TestValidateJWTFutureIssuedAt(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(2 * time.Hour).Unix(),
"iat": time.Now().Add(10 * time.Minute).Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for future iat")
}
if result.Error == nil {
t.Error("Expected error for future iat")
}
if !strings.Contains(result.Error.Error(), "future") {
t.Errorf("Expected 'future' in error, got: %v", result.Error)
}
}
func TestValidateJWTNotBeforeClaim(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(2 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"nbf": time.Now().Add(1 * time.Hour).Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for nbf in future")
}
if result.Error == nil {
t.Error("Expected error for nbf in future")
}
if !strings.Contains(result.Error.Error(), "not yet valid") {
t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error)
}
}
func TestValidateJWTInvalidFormat(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
}{
{"single part", "eyJhbGciOiJIUzI1NiJ9"},
{"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"},
{"four parts", "part1.part2.part3.part4"},
{"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ValidateToken(tt.token, true)
if result.Valid {
t.Error("Expected invalid result for malformed JWT")
}
if result.Error == nil {
t.Error("Expected error for malformed JWT")
}
})
}
}
func TestValidateOpaqueTokenValid(t *testing.T) {
validator := NewTokenValidator(nil)
token := "sk_live_abcdef123456GHIJKL789"
result := validator.ValidateToken(token, false)
if !result.Valid {
t.Errorf("Expected valid result, got error: %v", result.Error)
}
if result.TokenType != "Opaque" {
t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType)
}
}
func TestValidateOpaqueTokenTooShort(t *testing.T) {
validator := NewTokenValidator(nil)
token := "short"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for short token")
}
if result.Error == nil {
t.Error("Expected error for short token")
}
if !strings.Contains(result.Error.Error(), "too short") {
t.Errorf("Expected 'too short' in error, got: %v", result.Error)
}
}
func TestValidateOpaqueTokenWithSpaces(t *testing.T) {
validator := NewTokenValidator(nil)
token := "this token has spaces in it"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for token with spaces")
}
if result.Error == nil {
t.Error("Expected error for token with spaces")
}
if !strings.Contains(result.Error.Error(), "spaces") {
t.Errorf("Expected 'spaces' in error, got: %v", result.Error)
}
}
func TestValidateOpaqueTokenControlCharacters(t *testing.T) {
validator := NewTokenValidator(nil)
token := "token_with\x00control_char"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for token with control characters")
}
if result.Error == nil {
t.Error("Expected error for token with control characters")
}
if !strings.Contains(result.Error.Error(), "control character") {
t.Errorf("Expected 'control character' in error, got: %v", result.Error)
}
}
func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) {
validator := NewTokenValidator(nil)
token := "aaaaaabbbbbbccccccdddd"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for low entropy token")
}
if result.Error == nil {
t.Error("Expected error for low entropy token")
}
if !strings.Contains(result.Error.Error(), "entropy") {
t.Errorf("Expected 'entropy' in error, got: %v", result.Error)
}
}
func TestIsValidBase64URL(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
input string
expected bool
}{
{"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true},
{"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true},
{"valid numbers", "0123456789", true},
{"valid dash", "abc-def", true},
{"valid underscore", "abc_def", true},
{"valid equals", "abc=", true},
{"invalid at sign", "abc@def", false},
{"invalid space", "abc def", false},
{"invalid plus", "abc+def", false},
{"invalid slash", "abc/def", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.isValidBase64URL(tt.input)
if result != tt.expected {
t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result)
}
})
}
}
func TestExtractTime(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
claim interface{}
name string
expected bool
}{
{name: "float64", claim: float64(1609459200), expected: true},
{name: "int64", claim: int64(1609459200), expected: true},
{name: "int", claim: int(1609459200), expected: true},
{name: "string", claim: "not a timestamp", expected: false},
{name: "nil", claim: nil, expected: false},
{name: "map", claim: map[string]interface{}{}, expected: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.extractTime(tt.claim)
if tt.expected && result == nil {
t.Error("Expected non-nil time")
}
if !tt.expected && result != nil {
t.Error("Expected nil time")
}
})
}
}
func TestValidateTokenSize(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
maxSize int
expectError bool
}{
{"within limit", "short_token", 20, false},
{"at limit", "exactly_twenty_c", 16, false},
{"exceeds limit", "this_token_is_too_long", 10, true},
{"empty token", "", 10, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateTokenSize(tt.token, tt.maxSize)
if tt.expectError && err == nil {
t.Error("Expected error for oversized token")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if err != nil && !strings.Contains(err.Error(), "exceeds") {
t.Errorf("Expected 'exceeds' in error, got: %v", err)
}
})
}
}
func TestExtractClaims(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
"exp": float64(1609459200),
}
token := createTestJWTSimple(claims)
extracted, err := validator.ExtractClaims(token)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if extracted == nil {
t.Fatal("Expected non-nil claims")
}
if extracted["sub"] != "user123" {
t.Errorf("Expected sub 'user123', got %v", extracted["sub"])
}
if extracted["email"] != "user@example.com" {
t.Errorf("Expected email 'user@example.com', got %v", extracted["email"])
}
}
func TestExtractClaimsInvalidFormat(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
}{
{"single part", "onlyonepart"},
{"two parts", "two.parts"},
{"four parts", "one.two.three.four"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := validator.ExtractClaims(tt.token)
if err == nil {
t.Error("Expected error for invalid format")
}
if !strings.Contains(err.Error(), "invalid JWT format") {
t.Errorf("Expected 'invalid JWT format' in error, got: %v", err)
}
})
}
}
func TestCompareTokensEqual(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := "secret_token_12345"
token2 := "secret_token_12345"
if !validator.CompareTokens(token1, token2) {
t.Error("Expected tokens to be equal")
}
}
func TestCompareTokensDifferent(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := "secret_token_12345"
token2 := "secret_token_54321"
if validator.CompareTokens(token1, token2) {
t.Error("Expected tokens to be different")
}
}
func TestCompareTokensDifferentLength(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := "short"
token2 := "much_longer_token"
if validator.CompareTokens(token1, token2) {
t.Error("Expected tokens to be different (different lengths)")
}
}
func TestCompareTokensEmpty(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := ""
token2 := ""
if !validator.CompareTokens(token1, token2) {
t.Error("Expected empty tokens to be equal")
}
}
func TestValidateTokenMaliciousPayloads(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
}{
{"sql injection attempt", "'; DROP TABLE users; --"},
{"xss attempt", "<script>alert('xss')</script>"},
{"path traversal", "../../../etc/passwd"},
{"null bytes", "token\x00with\x00nulls"},
{"unicode exploit", "token\u0000\u0001\u0002"},
{"extremely long", strings.Repeat("a", 100000)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ValidateToken(tt.token, false)
if result.Valid {
if result.Claims != nil {
t.Logf("Token considered valid: %s", tt.name)
}
} else {
if result.Error == nil {
t.Error("Expected error for malicious payload")
}
}
})
}
}
// =============================================================================
// CONSOLIDATED TOKEN TESTS
// =============================================================================
@@ -2098,19 +1584,3 @@ func createTokenOfSize(baseToken string, targetSize int) string {
return baseToken
}
func createTestJWTSimple(claims map[string]interface{}) string {
header := map[string]interface{}{
"alg": "HS256",
"typ": "JWT",
}
headerJSON, _ := json.Marshal(header)
claimsJSON, _ := json.Marshal(claims)
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature"))
return headerB64 + "." + claimsB64 + "." + signature
}
+299
View File
@@ -0,0 +1,299 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file contains requestState-aware variants of the token validation
// functions. They read session field values from the captured snapshot in
// *requestState instead of calling session.GetX(), eliminating ~21 RLock
// acquisitions on sd.sessionMutex per request through the validation path
// (validateStandardTokens reads 17, validateAzureTokens reads 10,
// validateTokenExpiry reads 4 — and many are the SAME field). Under Yaegi
// each RLock costs ~1-5ms of interpreter dispatch.
//
// The non-RS variants are retained for paths that don't have a captured
// snapshot (tests that drive the validators directly, the Azure/Google path
// when reached without rs threading, etc).
package traefikoidc
import (
"encoding/base64"
"encoding/json"
"strings"
"time"
)
// isUserAuthenticatedRS is the requestState-aware variant of
// isUserAuthenticated. Dispatches to the right per-provider validator based
// on the configured provider, all of which read from rs instead of session.
func (t *TraefikOidc) isUserAuthenticatedRS(rs *requestState) (bool, bool, bool) {
if t.isAzureProvider() {
return t.validateAzureTokensRS(rs)
} else if t.isGoogleProvider() {
return t.validateGoogleTokensRS(rs)
}
return t.validateStandardTokensRS(rs)
}
// validateGoogleTokensRS handles Google-specific token validation. Currently
// delegates to standard token validation; retained as a hook for any future
// Google-specific behavior (matches the v1.0.20 layout of the non-RS variant).
func (t *TraefikOidc) validateGoogleTokensRS(rs *requestState) (bool, bool, bool) {
return t.validateStandardTokensRS(rs)
}
// validateTokenExpiryRS is the requestState-aware variant of validateTokenExpiry.
// Reads rs.refreshToken instead of session.GetRefreshToken() (4 RLocks avoided).
func (t *TraefikOidc) validateTokenExpiryRS(rs *requestState, token string) (bool, bool, bool) {
cachedClaims, found := t.tokenCache.Get(token)
if !found {
t.logger.Debug("Claims not found in cache after successful token verification")
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
expClaim, ok := cachedClaims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
expTimeObj := time.Unix(int64(expClaim), 0)
nowObj := time.Now()
if expTimeObj.Before(nowObj) {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
refreshThreshold := nowObj.Add(t.refreshGracePeriod)
if expTimeObj.Before(refreshThreshold) {
if rs.refreshToken != "" {
return true, true, false
}
return true, false, false
}
return true, false, false
}
// validateStandardTokensRS is the requestState-aware variant of
// validateStandardTokens. Replaces all session.GetX() calls (17 of them in
// the non-RS variant, dominated by GetRefreshToken called 11 times) with
// rs field reads. Same control flow.
//
//nolint:gocognit,gocyclo // Mirrors validateStandardTokens complexity by design.
func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bool) {
if !rs.authenticated {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, false
}
if rs.accessToken == "" {
if rs.refreshToken != "" {
// ID-token grace-period check (only when accessToken is absent).
if rs.idToken != "" {
parts := strings.Split(rs.idToken, ".")
if len(parts) == 3 {
if claimsData, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
var claims map[string]interface{}
if err := json.Unmarshal(claimsData, &claims); err == nil {
if expClaim, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(expClaim), 0)
if time.Now().After(expTime) {
expiredDuration := time.Since(expTime)
if expiredDuration > t.refreshGracePeriod {
return false, false, true
}
}
}
}
}
}
}
return false, true, false
}
return false, false, true
}
dotCount := strings.Count(rs.accessToken, ".")
isOpaqueToken := dotCount != 2
if isOpaqueToken {
if t.allowOpaqueTokens {
if err := t.validateOpaqueToken(rs.accessToken); err != nil {
errMsg := err.Error()
isTokenInvalid := strings.Contains(errMsg, "token is not active") ||
strings.Contains(errMsg, "revoked") ||
strings.Contains(errMsg, "token has expired")
if isTokenInvalid {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
if t.requireTokenIntrospection {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
// Transient introspection error: fall through to ID-token validation.
} else {
// Introspection succeeded.
if rs.idToken != "" {
return t.validateTokenExpiryRS(rs, rs.idToken)
}
// No ID token to corroborate an access token we cannot verify
// (Azure nonce-bearing Graph access tokens carry a proprietary,
// client-unverifiable signature). Do NOT authenticate on an
// unverified token: refresh if a refresh token is available,
// otherwise force re-authentication.
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
}
// Fall back to ID-token validation when opaque + no successful introspection.
if rs.idToken == "" {
if rs.refreshToken != "" {
return false, true, false
}
// Opaque access token, no ID token to corroborate it, and
// introspection was unavailable/disabled/errored (e.g.
// circuit-breaker open). There is nothing left to verify the token
// against, so fail closed and force re-authentication rather than
// trusting an unverified opaque token.
return false, false, true
}
if err := t.verifyToken(rs.idToken); err != nil {
if strings.Contains(err.Error(), "token has expired") {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiryRS(rs, rs.idToken)
}
// JWT access token present.
accessTokenValid := false
if err := t.verifyToken(rs.accessToken); err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "invalid audience") || strings.Contains(errMsg, "audience") {
if t.strictAudienceValidation {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
// Fall through to ID-token validation.
}
} else {
accessTokenValid = true
}
if rs.idToken == "" {
if accessTokenValid {
return t.validateTokenExpiryRS(rs, rs.accessToken)
}
if rs.refreshToken != "" {
return true, true, false
}
return true, false, false
}
if err := t.verifyToken(rs.idToken); err != nil {
if strings.Contains(err.Error(), "token has expired") {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
if accessTokenValid {
return t.validateTokenExpiryRS(rs, rs.accessToken)
}
return t.validateTokenExpiryRS(rs, rs.idToken)
}
// validateAzureTokensRS is the requestState-aware variant of validateAzureTokens.
// Eliminates 10 session.GetX() RLocks per Azure-path request.
func (t *TraefikOidc) validateAzureTokensRS(rs *requestState) (bool, bool, bool) {
if !rs.authenticated {
if rs.refreshToken != "" {
return false, true, false
}
return false, true, false
}
if rs.accessToken != "" {
if strings.Count(rs.accessToken, ".") == 2 {
if t.isUnverifiableAzureAccessToken(rs.accessToken) {
if rs.idToken != "" {
if err := t.verifyToken(rs.idToken); err != nil {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiryRS(rs, rs.idToken)
}
return true, false, false
}
if err := t.verifyToken(rs.accessToken); err != nil {
if rs.idToken != "" {
if err := t.verifyToken(rs.idToken); err != nil {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiryRS(rs, rs.idToken)
}
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiryRS(rs, rs.accessToken)
}
// Opaque access token.
if rs.idToken != "" {
return t.validateTokenExpiryRS(rs, rs.idToken)
}
return true, false, false
}
if rs.idToken != "" {
if err := t.verifyToken(rs.idToken); err != nil {
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiryRS(rs, rs.idToken)
}
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
}
-263
View File
@@ -1,263 +0,0 @@
package traefikoidc
import (
"bytes"
"encoding/base64"
"fmt"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/pool"
)
// TokenValidator provides unified token validation functionality
type TokenValidator struct {
logger *Logger
}
// NewTokenValidator creates a new token validator
func NewTokenValidator(logger *Logger) *TokenValidator {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &TokenValidator{
logger: logger,
}
}
// TokenValidationResult contains the result of token validation
type TokenValidationResult struct {
Error error
Claims map[string]interface{}
Expiry *time.Time
IssuedAt *time.Time
TokenType string
Valid bool
}
// ValidateToken performs comprehensive token validation
func (v *TokenValidator) ValidateToken(token string, requireJWT bool) TokenValidationResult {
result := TokenValidationResult{}
// Basic validation
if token == "" {
result.Error = fmt.Errorf("token is empty")
return result
}
// Check if it's a JWT or opaque token
dotCount := strings.Count(token, ".")
isJWT := dotCount == 2
if requireJWT && !isJWT {
result.Error = fmt.Errorf("token is not a valid JWT (found %d dots, expected 2)", dotCount)
return result
}
if isJWT {
return v.validateJWT(token)
} else {
return v.validateOpaqueToken(token)
}
}
// validateJWT validates a JWT token
func (v *TokenValidator) validateJWT(token string) TokenValidationResult {
result := TokenValidationResult{
TokenType: "JWT",
}
parts := strings.Split(token, ".")
if len(parts) != 3 {
result.Error = fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
return result
}
// Validate each part
for i, part := range parts {
if part == "" {
result.Error = fmt.Errorf("JWT part %d is empty", i)
return result
}
// Check for valid base64url characters
if !v.isValidBase64URL(part) {
result.Error = fmt.Errorf("JWT part %d contains invalid base64url characters", i)
return result
}
}
// Decode and parse claims
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
result.Error = fmt.Errorf("failed to decode JWT payload: %w", err)
return result
}
var claims map[string]interface{}
pm := pool.Get()
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
defer pm.PutJSONDecoder(decoder)
if err := decoder.Decode(&claims); err != nil {
result.Error = fmt.Errorf("failed to parse JWT claims: %w", err)
return result
}
result.Claims = claims
// Extract standard claims
if exp, ok := claims["exp"]; ok {
expTime := v.extractTime(exp)
if expTime != nil {
result.Expiry = expTime
// Check if expired
if time.Now().After(*expTime) {
result.Error = fmt.Errorf("token is expired (expired at %v)", expTime.Format(time.RFC3339))
return result
}
}
}
if iat, ok := claims["iat"]; ok {
iatTime := v.extractTime(iat)
if iatTime != nil {
result.IssuedAt = iatTime
// Check if issued in future
if iatTime.After(time.Now().Add(5 * time.Minute)) {
result.Error = fmt.Errorf("token issued in future (iat: %v)", iatTime.Format(time.RFC3339))
return result
}
}
}
// Check nbf (not before)
if nbf, ok := claims["nbf"]; ok {
nbfTime := v.extractTime(nbf)
if nbfTime != nil && time.Now().Before(*nbfTime) {
result.Error = fmt.Errorf("token not yet valid (nbf: %v)", nbfTime.Format(time.RFC3339))
return result
}
}
result.Valid = true
return result
}
// validateOpaqueToken validates an opaque token
func (v *TokenValidator) validateOpaqueToken(token string) TokenValidationResult {
result := TokenValidationResult{
TokenType: "Opaque",
}
// Check minimum length
if len(token) < 20 {
result.Error = fmt.Errorf("opaque token too short (length: %d)", len(token))
return result
}
// Check for spaces
if strings.Contains(token, " ") {
result.Error = fmt.Errorf("opaque token contains spaces")
return result
}
// Check for control characters
for i, char := range token {
if char < 32 || char == 127 {
result.Error = fmt.Errorf("opaque token contains control character at position %d", i)
return result
}
}
// Check entropy
if len(token) >= 20 {
uniqueChars := make(map[rune]bool)
for _, char := range token {
uniqueChars[char] = true
}
if len(uniqueChars) < 8 {
result.Error = fmt.Errorf("opaque token has insufficient entropy (unique chars: %d)", len(uniqueChars))
return result
}
}
result.Valid = true
return result
}
// isValidBase64URL checks if a string contains only valid base64url characters
func (v *TokenValidator) isValidBase64URL(s string) bool {
for _, char := range s {
if !((char >= 'A' && char <= 'Z') ||
(char >= 'a' && char <= 'z') ||
(char >= '0' && char <= '9') ||
char == '-' || char == '_' || char == '=') {
return false
}
}
return true
}
// extractTime extracts a time.Time from various claim formats
func (v *TokenValidator) extractTime(claim interface{}) *time.Time {
var timestamp int64
switch val := claim.(type) {
case float64:
timestamp = int64(val)
case int64:
timestamp = val
case int:
timestamp = int64(val)
default:
return nil
}
t := time.Unix(timestamp, 0)
return &t
}
// ValidateTokenSize checks if token size is within acceptable limits
func (v *TokenValidator) ValidateTokenSize(token string, maxSize int) error {
if len(token) > maxSize {
return fmt.Errorf("token exceeds maximum size (size: %d, max: %d)", len(token), maxSize)
}
return nil
}
// ExtractClaims extracts claims from a JWT without full validation
func (v *TokenValidator) ExtractClaims(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %w", err)
}
var claims map[string]interface{}
pm := pool.Get()
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
defer pm.PutJSONDecoder(decoder)
if err := decoder.Decode(&claims); err != nil {
return nil, fmt.Errorf("failed to parse claims: %w", err)
}
return claims, nil
}
// CompareTokens safely compares two tokens for equality
func (v *TokenValidator) CompareTokens(token1, token2 string) bool {
if len(token1) != len(token2) {
return false
}
// Use constant-time comparison to prevent timing attacks
var result byte
for i := 0; i < len(token1); i++ {
result |= token1[i] ^ token2[i]
}
return result == 0
}
+41 -5
View File
@@ -5,6 +5,7 @@ import (
"context"
"net/http"
"sync"
"sync/atomic"
"text/template"
"time"
@@ -64,8 +65,46 @@ type ProviderMetadata struct {
// It integrates with various OIDC providers, manages sessions, caches tokens, and handles
// the complete authentication flow. It's designed to work seamlessly with Traefik's
// plugin system and provides flexible configuration options.
// MetadataSnapshot is an immutable bundle of provider-metadata URLs that the
// plugin needs on the hot request path. Published atomically via
// TraefikOidc.metadataSnapshot; readers do exactly one atomic.Value.Load to
// access all fields. Replaces 3 per-request metadataMu.RLock acquisitions
// in middleware.ServeHTTP + token_manager paths, each of which paid
// 1-5ms of Yaegi-dispatch overhead.
//
// The fields are a strict subset of the metadataMu-guarded TraefikOidc
// fields; the legacy fields are still written under metadataMu for
// less-frequent code paths that have not been migrated.
type MetadataSnapshot struct {
IssuerURL string
JWKSURL string
TokenURL string
AuthURL string
RevocationURL string
EndSessionURL string
IntrospectionURL string
RegistrationURL string
}
type TraefikOidc struct {
lastMetadataRetryTime time.Time
// metadataSnapshot atomically publishes the read-mostly URL bundle.
// Hot-path readers (middleware.ServeHTTP, token verification) load it
// directly; less-frequent paths still acquire metadataMu.RLock and
// read the individual fields below.
metadataSnapshot atomic.Value
// lastMetadataRetryNano is the UnixNano timestamp of the last metadata
// recovery attempt. Stored atomically so the hot ServeHTTP path can
// throttle retries without acquiring metadataRetryMutex on every request.
lastMetadataRetryNano int64
// firstRequestStarted is 0 until the very first non-health request fires
// the background-task bootstrap; then it flips to 1 via CAS. Replaces the
// firstRequestMutex + firstRequestReceived combo which previously took
// a write lock on every non-health request forever.
firstRequestStarted int32
// metadataRefreshStartedAtomic is the CAS-only variant of the old
// metadataRefreshStarted bool. Both flags live under the same atomic so
// concurrent first-request goroutines race exactly once.
metadataRefreshStartedAtomic int32
jwkCache JWKCacheInterface
jwtVerifier JWTVerifier
ctx context.Context
@@ -126,21 +165,18 @@ type TraefikOidc struct {
frontchannelLogoutPath string
scopesSupported []string
scopes []string
extraAuthParams map[string]string
refreshGracePeriod time.Duration
maxRefreshTokenAge time.Duration
metadataMu sync.RWMutex
shutdownOnce sync.Once
metadataRetryMutex sync.Mutex
firstRequestMutex sync.Mutex
sessionInvalidationCache CacheInterface
refreshResultCache CacheInterface
minimalHeaders bool
stripAuthCookies bool
enableBackchannelLogout bool
enableFrontchannelLogout bool
firstRequestReceived bool
requireTokenIntrospection bool
metadataRefreshStarted bool
allowPrivateIPAddresses bool
disableReplayDetection bool
allowOpaqueTokens bool
+55 -10
View File
@@ -396,8 +396,16 @@ func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
return value, true
}
c.mu.RUnlock()
// Expired — fall through to the write-locked slow path below to
// remove the entry under exclusive access.
// Expired — return miss immediately. The periodic cleanup goroutine
// will evict the stale entry. NEVER fall through to the write-locked
// slow path for Token/JWK/Session caches: under Yaegi the write Lock
// at line 403 costs 10-100ms per acquisition, and Go's RWMutex
// writer-priority semantics block ALL new RLock callers while a Lock
// is pending. A single expired-token event turns every concurrent
// request from read-parallel into write-serialized — the exact
// convoy that produced the 737-goroutine pileup at 0x400275a608.
atomic.AddInt64(&c.misses, 1)
return nil, false
}
c.mu.Lock()
@@ -595,15 +603,28 @@ func (c *UniversalCache) removeItem(key string, item *CacheItem) {
// evictOldest evicts the oldest item from the cache (must be called with lock held)
func (c *UniversalCache) evictOldest() {
if elem := c.lruList.Back(); elem != nil {
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
if item, exists := c.items[key]; exists {
c.removeItem(key, item)
atomic.AddInt64(&c.evictions, 1)
if c.logger.IsDebug() {
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
}
elem := c.lruList.Back()
if elem == nil {
return
}
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
if item, exists := c.items[key]; exists && item.element == elem {
c.removeItem(key, item)
atomic.AddInt64(&c.evictions, 1)
if c.logger.IsDebug() {
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
}
return
}
// Defensive forward-progress guard: the back node is dangling — its key is
// absent from c.items, or c.items[key] points at a newer node (a stale
// duplicate). Drop the node directly so an eviction loop
// (`for ... && c.lruList.Len() > 0`) is guaranteed to terminate and can
// never spin holding c.mu.Lock(). With the updateLocalCache replace-in-place
// fix this branch should be unreachable, but it makes the spin impossible.
c.lruList.Remove(elem)
if c.currentSize > 0 {
c.currentSize--
}
}
@@ -936,6 +957,30 @@ func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl tim
}
now := time.Now()
// Replace an existing entry in place: update the item and move its single
// list node to the front. Without this, a repeat populate of the same key
// (the per-request Get->backend-hit path) would PushFront a duplicate node
// and overwrite c.items[key], orphaning the previous node. Orphans inflate
// currentMemory/currentSize and, once eviction deletes the key, leave a
// Back() node whose key is absent from c.items — so evictOldest() spins
// while holding c.mu.Lock(): the 100%-CPU write-lock convoy seen in pprof.
// setLocal dedups the same way; evictOldest also guards any dangling node.
if existing, exists := c.items[key]; exists {
c.currentMemory -= existing.Size
c.lruList.Remove(existing.element)
existing.Value = value
existing.Size = size
existing.ExpiresAt = now.Add(ttl)
existing.LastAccessed = now
existing.AccessCount++
existing.element = c.lruList.PushFront(key)
c.currentMemory += size
return nil
}
item := &CacheItem{
Key: key,
Value: value,
+84
View File
@@ -0,0 +1,84 @@
package traefikoidc
import (
"testing"
"time"
)
// newOrphanTestCache builds a Token-type cache with background cleanup disabled
// so the test fully controls lruList/items state.
func newOrphanTestCache(maxMem int64) *UniversalCache {
return NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
DefaultTTL: time.Hour,
MaxSize: 1_000_000, // large: keep the size-branch out of the way
MaxMemoryBytes: maxMem,
EnableMemoryLimit: maxMem > 0,
SkipAutoCleanup: true,
EnableAutoCleanup: false,
})
}
// TestUpdateLocalCache_NoOrphanElements is the direct red test: repeatedly
// populating the SAME key via updateLocalCache (the per-request Get->backend-hit
// path) must NOT leave dangling lruList elements. Today updateLocalCache blindly
// PushFronts + overwrites c.items[key] without removing the prior element, so the
// list grows one orphan per call while items stays at 1 entry.
func TestUpdateLocalCache_NoOrphanElements(t *testing.T) {
c := newOrphanTestCache(0) // memory limit off: isolate the orphan, no eviction
const key = "same-key"
for range 5 {
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
t.Fatalf("updateLocalCache: %v", err)
}
}
c.mu.RLock()
listLen := c.lruList.Len()
itemCount := len(c.items)
c.mu.RUnlock()
if itemCount != 1 {
t.Fatalf("items: got %d want 1", itemCount)
}
if listLen != 1 {
t.Fatalf("ORPHAN BUG: lruList.Len()=%d but items=%d (one list element per key expected)", listLen, itemCount)
}
}
// TestUpdateLocalCache_EvictionTerminates is the convoy reproducer: once orphans
// for a key exist and the memory-eviction loop runs, evictOldest() deletes the
// key from items on the first eviction, after which every remaining orphan at
// Back() has a key absent from items -> evictOldest() no-ops while lruList.Len()>0
// stays true -> infinite loop while holding c.mu.Lock(). That is the 100%-CPU
// holder + write-lock convoy observed in pprof.
func TestUpdateLocalCache_EvictionTerminates(t *testing.T) {
c := newOrphanTestCache(0) // start with memory limit OFF to accumulate orphans
const key = "same-key"
// Build 3 same-key list elements (3 orphans, items={key}).
for range 3 {
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
t.Fatalf("seed updateLocalCache: %v", err)
}
}
// Arm the trap: tiny memory limit so the next call enters the eviction loop.
c.mu.Lock()
c.config.MaxMemoryBytes = 1
c.mu.Unlock()
done := make(chan struct{})
go func() {
_ = c.updateLocalCache(key, "v", time.Hour) // triggers the eviction loop
close(done)
}()
select {
case <-done:
// fix present: loop made forward progress and returned
case <-time.After(2 * time.Second):
t.Fatal("INFINITE LOOP: eviction loop did not terminate within 2s (orphan whose key was deleted is never removed from lruList)")
}
}
+98 -1
View File
@@ -19,7 +19,7 @@ import (
// - true if the URL should be excluded from authentication, false otherwise.
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
if pathExcluded(currentRequest, excludedURL) {
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
@@ -27,6 +27,31 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
return false
}
// pathExcluded reports whether requestPath is covered by an excluded prefix at a
// natural boundary: an exact match, a sub-path ("/public" → "/public/x"), or a
// file extension ("/favicon" → "/favicon.ico"). It deliberately does NOT match
// an unrelated sibling such as "/publicsecret", so a configured exclusion can no
// longer be widened into an authentication bypass on a different resource.
func pathExcluded(requestPath, excluded string) bool {
excluded = strings.TrimRight(excluded, "/")
if excluded == "" {
// A "/" (root) exclusion only matches the root path, not everything.
return requestPath == "" || requestPath == "/"
}
if requestPath == excluded {
return true
}
if !strings.HasPrefix(requestPath, excluded) {
return false
}
switch requestPath[len(excluded)] {
case '/', '.':
return true
default:
return false
}
}
// buildAuthURL constructs the OIDC provider authorization URL.
// It builds the URL with all necessary parameters including client_id, scopes,
// PKCE parameters, and provider-specific parameters for Google and Azure.
@@ -146,6 +171,21 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
}
// Apply operator-configured extra authorization parameters (e.g.
// screen_hint, login_hint, ui_locales, prompt). These are added last but
// can never override parameters the plugin itself manages (client_id,
// state, nonce, redirect_uri, code_challenge, scope, response_type, ...):
// a key already present in params is left untouched, so this cannot
// weaken security-critical parameters.
for key, value := range t.extraAuthParams {
if params.Get(key) == "" {
params.Set(key, value)
t.logger.Debugf("TraefikOidc.buildAuthURL: Added extra auth param %s", key)
} else {
t.logger.Debugf("TraefikOidc.buildAuthURL: Skipped extra auth param %s (already set by plugin)", key)
}
}
// Read authURL with RLock
t.metadataMu.RLock()
authURL := t.authURL
@@ -273,6 +313,63 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
return nil
}
// validateDiscoveredEndpoint validates an endpoint URL obtained from the
// provider's OIDC/OAuth2 discovery document before the plugin issues any
// outbound request to it. A discovery document is attacker-influenced if the
// provider is malicious or its TLS is broken, so an unvalidated endpoint is an
// SSRF vector (e.g. jwks_uri or introspection_endpoint pointed at the cloud
// metadata service 169.254.169.254 or an internal host).
//
// Empty endpoints are allowed (they are optional). Link-local (which covers the
// 169.254.0.0/16 metadata range), multicast and unspecified addresses are
// always rejected. Private addresses are rejected unless allowPrivateIPAddresses
// is set. Loopback is rejected unless allowLoopback is true — which the caller
// sets only when the operator-configured providerURL is itself loopback (local
// development, in-cluster sidecars, tests), so production deployments pointed at
// a real provider still block loopback SSRF.
func (t *TraefikOidc) validateDiscoveredEndpoint(urlStr string, allowLoopback bool) error {
if urlStr == "" {
return nil
}
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %w", err)
}
if u.Scheme != "https" && u.Scheme != "http" {
return fmt.Errorf("disallowed URL scheme: %q", u.Scheme)
}
if u.Host == "" {
return fmt.Errorf("missing host in URL")
}
if ip := net.ParseIP(u.Hostname()); ip != nil {
switch {
case ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified():
return fmt.Errorf("endpoint host is a blocked address: %s", ip)
case ip.IsLoopback() && !allowLoopback:
return fmt.Errorf("endpoint host is a loopback address: %s", ip)
case ip.IsPrivate() && !t.allowPrivateIPAddresses:
return fmt.Errorf("endpoint host is a private address: %s", ip)
}
}
if strings.Contains(u.Path, "..") {
return fmt.Errorf("path traversal detected in URL path")
}
return nil
}
// sameHost reports whether two URLs share the same host:port (case-insensitive).
// Used to pin the credential-bearing introspection endpoint to the operator-
// configured provider so a poisoned discovery document cannot redirect the
// client secret to an attacker-controlled host.
func sameHost(a, b string) bool {
ua, erra := url.Parse(a)
ub, errb := url.Parse(b)
if erra != nil || errb != nil || ua.Host == "" || ub.Host == "" {
return false
}
return strings.EqualFold(ua.Host, ub.Host)
}
// validateHost validates a hostname or IP address for security.
// It prevents access to localhost, private networks, and known metadata endpoints.
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
+51
View File
@@ -554,3 +554,54 @@ func TestForceHTTPSIntegration(t *testing.T) {
"should use https from X-Forwarded-Proto when forceHTTPS is false")
})
}
// TestBuildAuthURLExtraAuthParams verifies operator-configured extra
// authorization parameters are appended to the authorization URL, and that
// they can never override parameters the plugin itself manages.
func TestBuildAuthURLExtraAuthParams(t *testing.T) {
t.Run("extra params are added (e.g. screen_hint=signup)", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.extraAuthParams = map[string]string{
"screen_hint": "signup",
"ui_locales": "en",
}
authURL := middleware.buildAuthURL(
"https://app.com/callback", "state123", "nonce456", "",
)
assert.Contains(t, authURL, "screen_hint=signup")
assert.Contains(t, authURL, "ui_locales=en")
})
t.Run("nil/empty extraAuthParams is a no-op", func(t *testing.T) {
middleware := createMinimalMiddleware()
// extraAuthParams left nil
authURL := middleware.buildAuthURL(
"https://app.com/callback", "state123", "nonce456", "",
)
assert.Contains(t, authURL, "client_id=test-client")
assert.NotContains(t, authURL, "screen_hint")
})
t.Run("extra params CANNOT override plugin-managed params", func(t *testing.T) {
middleware := createMinimalMiddleware()
middleware.extraAuthParams = map[string]string{
"client_id": "ATTACKER",
"state": "ATTACKER",
"redirect_uri": "https://evil.example.com",
"response_type": "token",
}
authURL := middleware.buildAuthURL(
"https://app.com/callback", "state123", "nonce456", "",
)
// Plugin-managed values must win; injected values must be absent.
assert.Contains(t, authURL, "client_id=test-client")
assert.NotContains(t, authURL, "ATTACKER")
assert.NotContains(t, authURL, "evil.example.com")
assert.Contains(t, authURL, "response_type=code")
})
}
+21 -3
View File
@@ -14,6 +14,19 @@ import (
"time"
)
// metadataSnap returns the most recently published *MetadataSnapshot, or nil
// if metadata has not yet been resolved. Single atomic.Value.Load — the hot
// ServeHTTP path uses this instead of acquiring metadataMu.RLock, which under
// Yaegi pays 1-5ms of interpreter-dispatch overhead per acquisition.
func (t *TraefikOidc) metadataSnap() *MetadataSnapshot {
v := t.metadataSnapshot.Load()
if v == nil {
return nil
}
s, _ := v.(*MetadataSnapshot)
return s
}
// safeLogDebug provides nil-safe logging for debug messages
func (t *TraefikOidc) safeLogDebug(msg string) {
if t.logger != nil {
@@ -122,7 +135,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return false
}
domain := parts[1]
domain := strings.ToLower(parts[1])
_, domainAllowed := t.allowedUserDomains[domain]
if domainAllowed {
@@ -223,8 +236,13 @@ func (t *TraefikOidc) Close() error {
// Get resource manager for cleanup
rm := GetResourceManager()
// Stop singleton tasks related to this instance
_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup
// singleton-token-cleanup is a process-global task shared by every plugin
// instance. Only stop it when the LAST instance is shutting down;
// otherwise one instance's teardown (e.g. a single config reload) would
// kill chunked-session/token cleanup for all surviving instances (rank 12).
if unregisterLiveInstance() <= 0 {
_ = rm.StopBackgroundTask("singleton-token-cleanup") // best effort, last instance only
}
// Stop metadata refresh task using same hash-based name as startMetadataRefresh
if t.providerURL != "" {
hash := sha256.Sum256([]byte(t.providerURL))
@@ -0,0 +1 @@
.docs
+36
View File
@@ -0,0 +1,36 @@
version: "2"
run:
timeout: 2m
linters:
default: none
enable:
- bodyclose
- errcheck
- errorlint
- gocritic
- gocyclo
- govet
- ineffassign
- misspell
- prealloc
- revive
- staticcheck
- unconvert
- unused
settings:
gocyclo:
min-complexity: 12
revive:
rules:
- name: var-naming
- name: indent-error-flow
- name: superfluous-else
- name: unused-parameter
- name: redefines-builtin-id
formatters:
enable:
- gofmt
- goimports
+42
View File
@@ -0,0 +1,42 @@
# Configuration for lukaszraczylo/semver-generator.
# Reference: https://github.com/lukaszraczylo/semver-generator
#
# Word matching is fuzzy + case-insensitive. The keywords below mirror the
# Conventional Commits prefixes used in this repo's git history. Same pattern
# as github.com/lukaszraczylo/go-telegram/.semver.yaml.
version: 1
# Respect existing v* tags as the version baseline. semver-generator finds
# the highest existing tag and bumps from there. With no tags yet, the first
# release computes from the empty base.
force:
existing: true
# Skip merge commits and machine-generated traffic that would otherwise
# spuriously bump the version.
blacklist:
- "Merge branch"
- "Merge pull request"
- "Merge remote-tracking branch"
- "go mod tidy"
wording:
patch:
- "fix"
- "chore"
- "docs"
- "test"
- "style"
- "refactor"
- "build"
- "ci"
- "perf"
minor:
- "feat"
major:
# Match only the canonical Conventional Commits trailer. The bare word
# "breaking" is too greedy under semver-generator's fuzzy match — it
# triggers on substrings inside a commit body and wrongly produces a
# major bump.
- "BREAKING CHANGE"
+122
View File
@@ -0,0 +1,122 @@
# oss-telemetry
A tiny Go client that fires one anonymous "this binary started" ping at a
central ingest endpoint. Designed to be embedded in your own open-source
projects so you can see approximate adoption and version spread without
collecting anything that could identify a user.
This is the **client library only**. The ingest endpoint, server-side
deduplication, rate limiting, and metrics are out of scope here.
## What it sends
A single HTTP `POST` per call to `Send`:
```json
{
"project": "my-tool",
"version": "1.2.3",
"ts": 1747782200
}
```
No identifiers, no IP, no machine info, no user data. The server dedupes
incoming requests; the client just fires and forgets.
## Failproof by design
- Never blocks the caller — work runs in a goroutine.
- Never panics — the goroutine recovers internally.
- Never returns errors — bad input and network failures are silently dropped.
- Never retries, never persists state, never reads back.
- 2-second hard timeout on every request.
- Zero third-party dependencies (Go stdlib only).
The endpoint is hardcoded and not overridable from consuming code, by design.
## Install
```bash
go get github.com/lukaszraczylo/oss-telemetry
```
Requires Go 1.22+.
## Usage
```go
package main
import (
"time"
telemetry "github.com/lukaszraczylo/oss-telemetry"
)
const version = "1.2.3"
func main() {
telemetry.Send("my-tool", version)
// ... your program runs ...
// Only needed for short-lived CLIs that may exit before the goroutine
// finishes its POST. Long-running services do not need this.
telemetry.Wait(2 * time.Second)
}
```
Call `Send` once at boot. Calling it more often just sends more pings; the
server deduplicates.
## Disabling telemetry
If you ship a binary that imports this library, link your users to this
section (`https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry`)
so they can find the opt-out paths.
Any one of these turns it off:
| Mechanism | How |
| ---------------------------------------- | ---------------------------------------------------------------- |
| Universal opt-out | `DO_NOT_TRACK=1` |
| Library-wide opt-out | `OSS_TELEMETRY_DISABLED=1` |
| Project-specific opt-out | `<UPPER_PROJECT>_DISABLE_TELEMETRY=1` |
| Programmatic (e.g. behind a `--no-telemetry` flag) | `telemetry.Disable()` before the first `Send` |
Project-specific env var derivation: uppercase the project name and replace
`-` with `_`. For `my-tool` the variable is `MY_TOOL_DISABLE_TELEMETRY`.
Truthy values: `1`, `true`, `yes`, `on` (case-insensitive). Anything else is
treated as "not set".
## Validation rules (silently dropped if violated)
- `project`: matches `^[a-z0-9-]+$`, length 164.
- `version`: matches `^[A-Za-z0-9.+_-]+$`, length 132.
Bad input is a no-op — the library never logs, never errors, never crashes.
## API
```go
// Fire a single ping in the background. Returns immediately.
func Send(project, version string)
// Suppress all subsequent Send calls in this process. Idempotent.
func Disable()
// Block until in-flight pings complete or timeout elapses, whichever first.
// Useful for short-lived CLI processes.
func Wait(timeout time.Duration)
```
## Testing
```bash
go test -race ./...
```
## License
Pick one before publishing. None bundled.
+367
View File
@@ -0,0 +1,367 @@
// Package telemetry sends anonymous usage pings for open-source Go projects.
//
// Wire format (POST application/json):
//
// {"project":"<name>","version":"<ver>","ts":<unix-seconds>}
//
// Design contract (failproof):
// - never blocks the caller (work happens in a goroutine)
// - never panics (background goroutine recovers internally)
// - never returns errors (silently no-ops on bad input or network failure)
// - never retries, never deduplicates, never persists state — the client
// fires a single ping and forgets; the server is responsible for
// deduplication, abuse protection, and aggregation
//
// Typical usage at program startup:
//
// telemetry.Send("my-tool", "1.2.3")
//
// For short-lived CLI processes that may exit before the goroutine finishes:
//
// telemetry.Send("my-tool", "1.2.3")
// defer telemetry.Wait(2 * time.Second)
//
// Disablement (any one of these suppresses pings):
// - environment variable DO_NOT_TRACK=1
// - environment variable OSS_TELEMETRY_DISABLED=1
// - environment variable <UPPER_PROJECT>_DISABLE_TELEMETRY=1
// (project name uppercased, dashes replaced with underscores)
// - calling telemetry.Disable() at runtime
package telemetry
import (
"bytes"
"context"
"net/http"
"os"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
defaultEndpoint = "https://oss.raczylo.com/v1/ping"
httpTimeout = 2 * time.Second
maxProjectLen = 64
maxVersionLen = 32
)
// Yaegi note: this package is consumed by the traefikoidc Traefik plugin, which
// Traefik interprets with Yaegi (it vendors and interprets dependency source).
// It therefore avoids generic stdlib types (atomic.Pointer[T], atomic.Bool) and
// range-over-int (Go 1.22), which some Traefik/Yaegi runtimes cannot interpret.
// Endpoint mutation uses a mutex-guarded string; the disabled flag uses the
// function-based sync/atomic int32 API (atomic.LoadInt32/StoreInt32).
var (
// endpointURL holds the ingest URL. Production code never mutates it; the
// setter exists only so the test suite can retarget it at httptest servers
// while goroutines started by Send are still in flight.
endpointMu sync.RWMutex
endpointURL = defaultEndpoint
disabled int32 // 0 = enabled, 1 = disabled; accessed via sync/atomic only
inflight sync.WaitGroup
client = &http.Client{Timeout: httpTimeout}
)
func currentEndpoint() string {
endpointMu.RLock()
defer endpointMu.RUnlock()
return endpointURL
}
func setEndpointURL(u string) {
endpointMu.Lock()
endpointURL = u
endpointMu.Unlock()
}
// Send fires a single anonymous telemetry ping in the background and returns
// immediately. It never blocks, never panics, and never reports errors.
// Invalid inputs, disabled state, and network failures are silently dropped.
//
// Version strings are validated against a SemVer-ish shape that mirrors the
// receiver. An optional leading "v" or "V" is accepted and stripped before
// transmission so that callers can pass either "v1.2.3" or "1.2.3"; the
// wire form is always the unprefixed canonical version.
//
// Call once at program startup. Calling repeatedly will send repeated pings;
// the server is responsible for deduplication.
func Send(project, version string) {
if atomic.LoadInt32(&disabled) != 0 {
return
}
if isDisabledByEnv(project) {
return
}
if !validProject(project) || !validVersion(version) {
return
}
canonical := normalizeVersion(version)
inflight.Add(1)
go func() {
defer inflight.Done()
defer func() { _ = recover() }()
dispatch(project, canonical)
}()
}
// SendForModule is the recommended call form for Go libraries: it resolves
// the version automatically from Go's build info for the given module path
// so consumers do not need to maintain a hand-bumped version constant in
// source. Behaviour and contract are otherwise identical to [Send].
//
// Resolution order:
//
// 1. debug.ReadBuildInfo Deps entry for modulePath (authoritative when the
// library is consumed via go.mod);
// 2. debug.ReadBuildInfo Main when the library is itself the main module
// (e.g. running its own tests or examples);
// 3. fallback parameter, used only when build info is unavailable or
// unhelpful (replace directives, detached `go run`, ldflag override).
//
// Any leading "v" reported by build info is stripped to match the canonical
// wire form. Empty / "(devel)" build versions are skipped in favour of the
// next resolution source. Typical usage:
//
// telemetry.SendForModule("my-tool", "github.com/me/my-tool", "0.0.0-dev")
func SendForModule(project, modulePath, fallback string) {
Send(project, ResolveModuleVersion(modulePath, fallback))
}
// ResolveModuleVersion implements the version resolution used by
// SendForModule. Exposed for callers that need to format the resolved
// version (e.g. logging) without firing a ping.
func ResolveModuleVersion(modulePath, fallback string) string {
if info, ok := debug.ReadBuildInfo(); ok {
for _, d := range info.Deps {
if d != nil && d.Path == modulePath && isUsableBuildVersion(d.Version) {
return strings.TrimPrefix(d.Version, "v")
}
}
if info.Main.Path == modulePath && isUsableBuildVersion(info.Main.Version) {
return strings.TrimPrefix(info.Main.Version, "v")
}
}
return fallback
}
func isUsableBuildVersion(v string) bool {
return v != "" && v != "(devel)"
}
// Disable suppresses all subsequent Send calls in this process.
// Idempotent and safe to call from any goroutine.
func Disable() {
atomic.StoreInt32(&disabled, 1)
}
// Wait blocks until all in-flight pings have completed, or until timeout
// elapses — whichever comes first. Useful for short-lived CLI processes
// that may otherwise exit before the background goroutine finishes its POST.
//
// A non-positive timeout returns immediately.
func Wait(timeout time.Duration) {
if timeout <= 0 {
return
}
done := make(chan struct{})
go func() {
inflight.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(timeout):
}
}
func dispatch(project, version string) {
body := buildPayload(project, version, time.Now().Unix())
ctx, cancel := context.WithTimeout(context.Background(), httpTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, currentEndpoint(), bytes.NewReader(body))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return
}
_ = resp.Body.Close()
}
// buildPayload writes the JSON body without encoding/json. The validators
// restrict project and version to characters that never require JSON
// escaping, so direct concatenation is safe.
func buildPayload(project, version string, ts int64) []byte {
// Wrapper text plus 20 chars for a signed int64.
const overhead = len(`{"project":"","version":"","ts":}`) + 20
buf := make([]byte, 0, len(project)+len(version)+overhead)
buf = append(buf, `{"project":"`...)
buf = append(buf, project...)
buf = append(buf, `","version":"`...)
buf = append(buf, version...)
buf = append(buf, `","ts":`...)
buf = strconv.AppendInt(buf, ts, 10)
buf = append(buf, '}')
return buf
}
func validProject(p string) bool {
n := len(p)
if n == 0 || n > maxProjectLen {
return false
}
for i := 0; i < n; i++ {
c := p[i]
switch {
case c >= 'a' && c <= 'z',
c >= '0' && c <= '9',
c == '-':
default:
return false
}
}
return true
}
// validVersion accepts SemVer-ish version strings with an optional leading
// "v"/"V" prefix. Acceptable shape (after stripping the leading v):
//
// MAJOR[.MINOR[.PATCH]] ("-"prerelease)? ("+"build)?
//
// where MAJOR/MINOR/PATCH are ASCII digit sequences and the prerelease/build
// payloads are non-empty runs of [0-9A-Za-z.-]. This intentionally mirrors
// the receiver's version regex so junk like "dev" or "git-2026-05-22" never
// leaves the client (where it would only be rejected with HTTP 400 anyway).
func validVersion(v string) bool {
n := len(v)
if n == 0 || n > maxVersionLen {
return false
}
if v[0] == 'v' || v[0] == 'V' {
v = v[1:]
}
if len(v) == 0 {
return false
}
return checkSemverShape(v)
}
// normalizeVersion strips an optional leading "v"/"V" so the on-the-wire
// version matches the form stored server-side by the version refresher cron
// (which also strips the leading v from release tags). Callers may pass
// either "v1.2.3" or "1.2.3" — only the unprefixed form is transmitted.
func normalizeVersion(v string) string {
if len(v) > 0 && (v[0] == 'v' || v[0] == 'V') {
return v[1:]
}
return v
}
func checkSemverShape(s string) bool {
i := 0
if !readDigitRun(s, &i) {
return false
}
for groups := 0; groups < 2 && i < len(s) && s[i] == '.'; groups++ {
i++
if !readDigitRun(s, &i) {
return false
}
}
if i < len(s) && s[i] == '-' {
i++
if !readIdentRun(s, &i, '+') {
return false
}
}
if i < len(s) && s[i] == '+' {
i++
if !readIdentRun(s, &i, 0) {
return false
}
}
return i == len(s)
}
func readDigitRun(s string, i *int) bool {
start := *i
for *i < len(s) && s[*i] >= '0' && s[*i] <= '9' {
*i++
}
return *i > start
}
// readIdentRun consumes [0-9A-Za-z.-] until end-of-string or until `stop`
// is hit (stop=0 disables the early-stop check). Returns false if no
// characters were consumed (i.e. empty payload).
func readIdentRun(s string, i *int, stop byte) bool {
start := *i
for *i < len(s) {
c := s[*i]
if stop != 0 && c == stop {
break
}
valid := (c >= '0' && c <= '9') ||
(c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
c == '.' || c == '-'
if !valid {
return false
}
*i++
}
return *i > start
}
func isDisabledByEnv(project string) bool {
if truthy(os.Getenv("DO_NOT_TRACK")) {
return true
}
if truthy(os.Getenv("OSS_TELEMETRY_DISABLED")) {
return true
}
if project == "" {
return false
}
key := projectEnvKey(project)
return truthy(os.Getenv(key))
}
// projectEnvKey returns "<UPPER_PROJECT>_DISABLE_TELEMETRY" using a single
// allocation rather than chained strings.ToUpper(strings.ReplaceAll(...)).
func projectEnvKey(project string) string {
const suffix = "_DISABLE_TELEMETRY"
buf := make([]byte, 0, len(project)+len(suffix))
for i := 0; i < len(project); i++ {
c := project[i]
switch {
case c == '-':
c = '_'
case c >= 'a' && c <= 'z':
c -= 'a' - 'A'
}
buf = append(buf, c)
}
buf = append(buf, suffix...)
return string(buf)
}
func truthy(s string) bool {
switch strings.ToLower(strings.TrimSpace(s)) {
case "1", "true", "yes", "on":
return true
}
return false
}
+3
View File
@@ -24,6 +24,9 @@ github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# github.com/lukaszraczylo/oss-telemetry v0.2.3
## explicit; go 1.22
github.com/lukaszraczylo/oss-telemetry
# github.com/pmezard/go-difflib v1.0.0
## explicit
github.com/pmezard/go-difflib/difflib
+17
View File
@@ -0,0 +1,17 @@
package traefikoidc
// devPluginVersion is the placeholder carried by source-tree / local / test
// builds. Telemetry is suppressed while the plugin still reports this sentinel,
// so only stamped release builds emit a "plugin loaded" ping.
const devPluginVersion = "0.0.0-dev"
// traefikoidcPluginVersion is the released version of this plugin. It is stamped
// at release time by ./workflow-prepare.sh (invoked by the shared go-release
// workflow before GoReleaser builds and tags), which rewrites the string below
// to the computed semver.
//
// Traefik runs this plugin under Yaegi, where the version cannot be resolved
// from build info at runtime (debug.ReadBuildInfo sees Traefik's build graph,
// not the interpreted plugin). This build-stamped constant is therefore the
// single source of truth for the version reported by anonymous usage telemetry.
const traefikoidcPluginVersion = "0.0.0-dev"
+67
View File
@@ -0,0 +1,67 @@
#!/usr/bin/env bash
#
# workflow-prepare.sh — stamp the release version into version.go at build time.
#
# The shared go-release workflow (lukaszraczylo/shared-actions go-release.yaml)
# runs this script, if present, from the repository root BEFORE GoReleaser
# builds and tags. Traefik runs this plugin under Yaegi, where the version
# cannot be resolved from build info at runtime, so the released semver must be
# baked into source here.
#
# Version source — first non-empty wins:
# $VERSION $VERSION_TAG $SEMVER $NEW_VERSION $RELEASE_VERSION
# A leading "v"/"V" is stripped.
#
# NOTE: go-release.yaml @main does not yet pass the computed version into this
# step's environment. Add it to the "Run workflow prepare script" step, e.g.:
# env:
# VERSION: ${{ needs.version.outputs.version }} # bare, no leading v
#
# The shared workflow runs this script in its test, version AND release jobs,
# but only the release job has a computed version. So a missing version is a
# no-op (leave the dev sentinel) — NOT a hard failure, otherwise the test/version
# jobs would break. A malformed version that IS provided is a hard error. Wire
# the env only on the release job's prepare step (see header note above).
set -euo pipefail
FILE="version.go"
CONST="traefikoidcPluginVersion"
VER="${VERSION:-${VERSION_TAG:-${SEMVER:-${NEW_VERSION:-${RELEASE_VERSION:-}}}}}"
VER="${VER#v}"
VER="${VER#V}"
if [ -z "$VER" ]; then
if [ "${GITHUB_ACTIONS:-}" = "true" ]; then
echo "workflow-prepare: WARNING no version provided; leaving ${FILE} at the dev placeholder. If this is the release build, set 'env: VERSION: \${{ needs.version.outputs.version }}' on the release job's prepare step — otherwise the release ships 0.0.0-dev and emits no telemetry." >&2
else
echo "workflow-prepare: no version provided; leaving dev placeholder in ${FILE} (local build)"
fi
exit 0
fi
# Accept MAJOR[.MINOR[.PATCH]] with optional -prerelease / +build (semver-ish,
# matching the oss-telemetry receiver's validator).
if ! printf '%s' "$VER" | grep -Eq '^[0-9]+(\.[0-9]+){0,2}(-[0-9A-Za-z.-]+)?(\+[0-9A-Za-z.-]+)?$'; then
echo "workflow-prepare: ERROR version '${VER}' is not semver-shaped" >&2
exit 1
fi
if [ ! -f "$FILE" ]; then
echo "workflow-prepare: ERROR ${FILE} not found (run from repository root)" >&2
exit 1
fi
# Rewrite only the value of ${CONST}, anchored on the constant name so the
# sibling devPluginVersion sentinel is left untouched.
tmp="$(mktemp)"
sed -E "s/(${CONST}[[:space:]]*=[[:space:]]*\")[^\"]*(\")/\1${VER}\2/" "$FILE" > "$tmp"
mv "$tmp" "$FILE"
if ! grep -Eq "${CONST}[[:space:]]*=[[:space:]]*\"${VER}\"" "$FILE"; then
echo "workflow-prepare: ERROR failed to stamp version into ${FILE}" >&2
exit 1
fi
command -v gofmt >/dev/null 2>&1 && gofmt -w "$FILE"
echo "workflow-prepare: stamped ${CONST} = \"${VER}\" in ${FILE}"