Compare commits

...

73 Commits

Author SHA1 Message Date
lukaszraczylo 4fa579ccf4 test(oidcgate): integration test with real middleware against mock IdP (#142) 2026-05-23 03:12:17 +01:00
lukaszraczylo 2af05701dc build(release): publish multi-arch oidcgate Docker image per release tag
- Add 'oidcgate' build entry (linux/darwin × amd64/arm64) to goreleaser.
- Add per-OS/arch tar.gz archives for the daemon binary.
- Add dockers + docker_manifests entries publishing
  ghcr.io/lukaszraczylo/oidcgate:vX.Y.Z (release tag), :vX.Y, :vX, :latest
  as multi-arch manifests (linux/amd64 + linux/arm64).
- Add cmd/oidcgate/Dockerfile (distroless static, nonroot user).
- Sign images with cosign keyless (docker_signs).
- Preserve existing source-only Traefik plugin archive via meta:true.
- Update README to advertise the published image.
2026-05-19 17:14:29 +01:00
lukaszraczylo 03a755cb53 docs(oidcgate): expand user guide and cross-link
- Add HAProxy and Envoy ext_authz_http wiring snippets.
- Add full OIDCGATE_* env-var inventory (26 fields).
- Add Security Posture section (X-Forwarded-Uri sanitisation, excludedURLs
  guardrail, callbackURL/logoutURL validation).
- Add Bearer-token (M2M) auth composition section with link to BEARER_AUTH.md.
- Add Operational Guidance section (healthz/readyz ACL, Redis for multi-replica,
  no built-in metrics, graceful shutdown deadline).
- Add Debugging section (sentinel path, silent open-redirect rejections,
  /readyz warm-up).
- Cross-link from docs/CONFIGURATION.md.
2026-05-19 16:59:15 +01:00
lukaszraczylo dc0e7e0238 fix(oidcgate): gosec G304 — clean config path + native #nosec directive
The //nolint:gosec directive only suppresses golangci-lint; the standalone
gosec GitHub Action uses its own '#nosec G304 -- reason' syntax. Use both
filepath.Clean as canonical mitigation and the native directive.
2026-05-19 16:41:57 +01:00
lukaszraczylo b2e79d8798 Merge remote-tracking branch 'origin/main' into conflict-resolve
# Conflicts:
#	docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md
#	middleware.go
#	settings.go
#	types.go
2026-05-19 16:41:34 +01:00
lukaszraczylo 52ef32ece7 fix(oidcgate): security hardening — sanitize XFU, guardrails, validations 2026-05-19 15:17:04 +01:00
lukaszraczylo 3bf7c60ef4 chore: gofmt 2026-05-19 15:00:42 +01:00
lukaszraczylo 775ca7afc3 docs(oidcgate): user-facing setup guide and nginx/Caddy/Traefik wiring 2026-05-19 14:25:38 +01:00
lukaszraczylo a1273e6883 feat(oidcgate): main entrypoint with graceful shutdown 2026-05-19 14:22:46 +01:00
lukaszraczylo 0bc0079a58 refactor(oidcgate): WriteTimeout for slowloris guard, nolint reason 2026-05-19 14:18:28 +01:00
lukaszraczylo 20294f1339 feat(oidcgate): mux wiring and http.Server with graceful shutdown 2026-05-19 14:13:13 +01:00
lukaszraczylo 43938ed8a8 feat(oidcgate): healthz and readyz endpoints 2026-05-19 14:08:53 +01:00
lukaszraczylo 46679c82eb refactor(oidcgate): simplify cloneAndRewrite, flip ?rd precedence, assert XFU passthrough 2026-05-19 14:07:44 +01:00
lukaszraczylo a46be72be5 feat(oidcgate): auth/start/callback/logout endpoint handlers 2026-05-19 13:59:20 +01:00
lukaszraczylo 91966c1bec refactor(oidcgate): idempotent Finalize; document and test 307/308 intercept 2026-05-19 13:57:15 +01:00
lukaszraczylo c465fc888b feat(oidcgate): response-writer interceptor converts 302->401 for /oauth2/auth 2026-05-19 13:50:03 +01:00
lukaszraczylo 047fea3c75 refactor(oidcgate): drop unreachable lowercase prefix; add multi-value mirror test 2026-05-19 13:48:13 +01:00
lukaszraczylo 0c092a5a22 feat(oidcgate): synthetic success handler mirrors X-* headers to response 2026-05-19 13:41:51 +01:00
lukaszraczylo 8f458b4f6e fix(oidcgate): quality fixes — rune-safe snake-upper, drop dead import, listen validation, nested-struct test 2026-05-19 13:40:24 +01:00
lukaszraczylo 17c28fd574 feat(oidcgate): YAML config loader with env-var overrides 2026-05-19 13:30:28 +01:00
lukaszraczylo 21cc2ed747 refactor(lib): match codebase metadataMu lock pattern in Ready() 2026-05-19 13:25:13 +01:00
lukaszraczylo ded90e5dc1 feat(lib): add (*TraefikOidc).Ready() metadata-discovery readiness accessor 2026-05-19 13:19:20 +01:00
lukaszraczylo 46777d0510 fix(lib): also route X-Auth-Request-Redirect through originalRequestURI helper 2026-05-19 13:14:16 +01:00
lukaszraczylo f990365cb8 feat(lib): add TrustForwardedURI to honor X-Forwarded-Uri for post-login redirect target 2026-05-19 13:07:35 +01:00
lukaszraczylo 85eb9ecd16 docs: add oidcgate implementation plan 2026-05-19 13:00:56 +01:00
lukaszraczylo 3495e70cbb docs: add oidcgate Tier 1 forward-auth daemon design 2026-05-19 12:51:41 +01:00
lukaszraczylo a548665edb feat: opt-in M2M bearer-token authentication (supersedes #93) (#140)
* docs: bearer-token auth design spec

* docs: harden bearer-auth spec with security review findings

* feat(bearer): opt-in M2M bearer-token authentication

Adds an opt-in Authorization: Bearer <jwt> path for machine-to-machine
clients. Replaces and supersedes the broken approach in PR #93
(synthetic-session that omitted user_identifier and skipped ID-token
rejection / replay-protection-semantics / kid-pinning / etc.).

Design

  Two auth entrypoints feed one shared post-auth pipeline:

    cookie path  ─┐
                  ├── forwardAuthorized(rw, req, *principal)
    bearer path  ─┘    (roles/groups, header injection, security
                        headers, cookie strip, forward)

  buildPrincipalFromSession and buildPrincipalFromBearerToken produce
  the same `principal` value type. forwardAuthorized is session-agnostic
  and runs the existing post-auth work; processAuthorizedRequest now
  wraps it with the session-specific concerns (backchannel-logout,
  dirty/Save). The cookie path's behaviour is byte-identical to before
  this PR; the existing test suite passes unmodified.

Security hardening baked into the bearer path

  - Audience MANDATORY. Startup fails when EnableBearerAuth=true and
    Audience is empty.
  - BearerIdentifierClaim defaults to "sub"; "email" is rejected at
    startup to avoid the unverified-email spoofing footgun. Cookie
    path's UserIdentifierClaim is unaffected and still defaults to
    "email".
  - ID tokens explicitly rejected via the existing detectTokenType
    helper (nonce, typ=at+jwt, token_use, scope, aud-vs-clientID
    heuristics); belt-and-braces nonce/token_use=id rejection on top.
  - alg pinned to asymmetric allowlist (RS/PS/ES 256/384/512) BEFORE
    JWKS fetch, blocking alg=none and alg=HS* probes from amplifying
    into upstream calls.
  - kid length capped at 256 bytes and charset-restricted before JWKS
    fetch, blocking pathological-kid JWKS amplification.
  - Multi-audience tokens require azp == clientID.
  - iat upper-age bound (MaxTokenAgeSeconds, default 24h) bounds clock-
    manipulation and forever-token abuse.
  - Identifier sanitization: length cap, control-char + bidi-override
    + delimiter (, ; =) rejection.
  - Per-IP failure throttle: configurable threshold/window/penalty;
    returns 429 + Retry-After. Limits offline-guessing-style attacks
    and protects the shared rate-limiter / JWKS endpoint.
  - JTI replay marking suppressed via new internal verifyOpts
    {skipReplayMarking} so the same bearer can be reused until exp;
    the blacklist Get stays active so RevokeToken still terminates a
    bearer token immediately. The existing exported VerifyToken
    interface is unchanged so all mocks continue to work.
  - Cookie wins by default when both bearer and cookie are present
    (safer against browser/extension/proxy bearer injection).
    Operator can flip via BearerOverridesCookie.
  - Authorization header stripped on forward by default; also stripped
    on excluded URLs so the token can't leak into health/metrics
    downstream logs.
  - Optional RFC 7662 introspection via existing
    requireTokenIntrospection. Introspection-endpoint failure returns
    503 (distinguishes infra from token rejection).
  - 401s use RFC 6750 WWW-Authenticate hints (toggleable). Failure
    reason is logged at debug; raw tokens are never logged.

Implementation

  - principal.go: pure-data principal type and buildPrincipalFromSession.
  - bearer_auth.go: alg/kid pin, classifier, identifier sanitization,
    multi-aud azp gate, iat age check, per-IP failure tracker,
    handleBearerRequest, buildPrincipalFromBearerToken.
  - token_manager.go: VerifyToken now wraps a new verifyTokenWithOpts
    that accepts internal-only verifyOpts. Existing callers, the
    TokenVerifier interface, and all mocks unchanged.
  - middleware.go: extracted forwardAuthorized from
    processAuthorizedRequest; wired bearer detection after init wait
    + after bypass; excluded-URL Authorization strip when bearer
    enabled.
  - settings.go: ten new config fields with defaults applied in
    CreateConfig.
  - main.go: startup validation for audience + identifier-claim
    guard; bearer failure tracker init.

Tests

  - bearer_auth_test.go: table-driven helper tests for every new
    component (parseBearerJOSEHeader, sanitizeBearerIdentifier,
    resolveBearerIdentifier, enforceMultiAudienceAzp, enforceIatAge,
    bearerFailureTracker, detectBearerToken). Integration tests
    through ServeHTTP covering happy path, ID-token rejection,
    alg=none rejection, oversized kid, multi-aud with/without azp,
    iat-too-old, bidi identifier, replay (100x reuse), 429 throttle
    trip, excluded-URL strip, roles gate, cookie-wins precedence,
    BearerOverridesCookie, oversized token, malformed JWT,
    feature-off pass-through. Startup validation for audience-
    required and email-identifier-rejected.
  - All existing tests pass unmodified (cookie-path regression).
  - go vet clean. golangci-lint clean (0 issues). Race detector
    clean on bearer tests.

Documentation

  - README.md: bearer auth section with security highlights and
    config snippet; doc link in the index.
  - .traefik.yml: commented config block exposing every bearer knob.
  - docs/CONFIGURATION.md: new subsection with full parameter table.
  - docs/BEARER_AUTH.md: threat model, hardening matrix, failure
    response table, operational guidance, known follow-ups.
  - docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md:
    design spec + security-review hardening history.

* fix(cache): redact raw cache keys in debug logs (CodeQL go/clear-text-logging)

CodeQL flagged 9 high-severity alerts (go/clear-text-logging) where the
in-memory cache and the hybrid L1+L2 backend printed `key=%s` at debug.
Cache callers (token cache, blacklist, introspection cache) pass raw
access / refresh / id tokens as cache keys, so any debug-enabled
deployment would write them to log streams.

Pre-existing issue. CodeQL started flagging it on this PR because the
new bearer-auth path adds a data-flow source (req.Header.Get("Authorization"))
that reaches the existing logging sinks via the same cache. The cookie
path had the same risk but wasn't tracked as taint by CodeQL.

Fix: hash the key (SHA-256[:8] hex) before printing. Same approach the
bearer-auth logger uses for principal identifiers (spec §13). Doesn't
change cache semantics — same key still produces the same hash, so
debug correlation across log lines is preserved without exposing the
raw value.

Touches both affected packages:
  - internal/cache/cache.go (2 sites: Set + LRU eviction)
  - internal/cache/backends/hybrid.go (12 sites: L1/L2 read/write/fallback)

New helper `redactKey` colocated with each package (unexported,
package-local) keeps the change blast radius narrow. Tests green; lint
clean.

* docs(bearer): how to obtain bearer tokens from the OIDC provider

Adds a section walking operators through the OAuth 2.0 client_credentials
flow (RFC 6749 §4.4) and the JWT bearer assertion alternative (RFC 7523),
with a worked Auth0-shape curl example, a per-provider quick reference
(Auth0, Okta, Keycloak, Entra v2, Cognito, GitLab, Google), operational
notes (token TTL, caching, JWKS rotation, revocation, scope vs audience,
secret hygiene), and a three-line validation loop.

Most common operator confusion: "I enabled the feature but tokens get
401'd" — almost always missing or wrong audience. The new section makes
the audience-matching requirement loud, with per-provider parameter
names so people don't have to dig through IdP docs.

Locations:
  - docs/BEARER_AUTH.md  — full section under "Quick start"
  - README.md            — short snippet + deep link
2026-05-18 17:35:37 +01:00
lukaszraczylo fcb21a36e6 docs: harden bearer-auth spec with security review findings 2026-05-18 16:24:52 +01:00
lukaszraczylo a6c38c0747 docs: bearer-token auth design spec 2026-05-18 15:35:12 +01:00
lukaszraczylo 8c5df82dcf fix(azure): treat Microsoft proprietary access tokens as opaque (#134) (#138)
Followup to issue #134 — two reporters returned saying that even with the
JWKS caching fix in v1.0.7/v1.0.8, every request emitted:

  ERROR: TraefikOidcPlugin: UNKNOWN token verification failed:
    signature verification failed: crypto/rsa: verification error
  ERROR: TraefikOidcPlugin: DIAGNOSTIC: Signature verification failed for
    kid=<kid>, alg=RS256: crypto/rsa: verification error

Root cause: when an Azure tenant is configured without a custom API
resource, Microsoft issues access tokens for Microsoft Graph (or Azure
Mgmt). These tokens carry a `nonce` value in the JWT *header*; the bytes
that get signed contain SHA256(nonce), while the wire token ships the
original nonce. Any standard JWS verifier rejects the signature, which is
exactly Microsoft's intent — they document the format as proprietary and
tell client apps not to validate it
(https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
"you can't validate tokens for Microsoft Graph according to these rules
due to their proprietary format").

validateAzureTokens was nonetheless attempting JWT verification on every
JWT-shaped access token, then silently falling back to the ID token when
verification failed. Auth still worked end-to-end, but every request
spammed two error log lines.

Two-layer defense:

* validateAzureTokens now detects the proprietary-nonce header before
  calling verifyToken on the access token. When detected, the token is
  treated as opaque (matching the existing branch for non-JWT tokens) and
  validation proceeds via the ID token, exactly as Microsoft prescribes.

* VerifyJWTSignatureAndClaims downgrades the DIAGNOSTIC error log to
  debug for tokens carrying the same proprietary marker, in case any
  path outside validateAzureTokens reaches it.

Authorization still hinges on a separately-verifiable ID token — the
confused-deputy guard from CWE-441 is preserved (and explicitly tested).
2026-05-11 17:31:37 +01:00
lukaszraczylo aa96e9dbee Add sponsorship
Just in case you appreciate this project, feel generous and want to sponsor my caffeine addiction.
2026-05-10 21:25:26 +01:00
lukaszraczylo 1e33bb0a4d feat(auth): support private_key_jwt and client_secret_basic (#137)
revocation endpoints, joining the existing client_secret_post default.
Both are opt-in via the new clientAuthMethod config field. Closes #135.

private_key_jwt (RFC 7523 §2.2 / OpenID Connect Core §9)
========================================================
Plugin signs a short-lived JWT with a configured private key and presents
it as client_assertion. Use when the IdP enforces short secret TTLs or
requires secretless client auth (Microsoft Entra ID / Azure AD, Okta,
Auth0, Keycloak).

New Config fields:
  clientAuthMethod          (default: client_secret_post)
  clientAssertionPrivateKey (inline PEM)
  clientAssertionKeyPath    (PEM file path; mutually exclusive)
  clientAssertionKeyID      (JWS kid header — required)
  clientAssertionAlg        (default: RS256; RS/PS/ES 256–512 supported)

PEM forms accepted: PKCS#8, PKCS#1, SEC1.
Assertion claims: iss=sub=clientID, aud=tokenURL, iat=now, exp=now+60s,
random 16-byte hex jti per request. ECDSA signatures are raw r||s per
RFC 7515 (not ASN.1).

client_secret_basic (RFC 6749 §2.3.1)
=====================================
Sends credentials in the Authorization: Basic header instead of the
body. Both halves are form-urlencoded individually before base64 — that
encoding step is required by the spec and is NOT what stdlib's
http.Request.SetBasicAuth does, so the plugin uses its own helper. The
form body omits client_id and client_secret on this path.

Wire-up
=======
Both methods are dispatched at the same two call sites:
  helpers.go:exchangeTokens — auth_code + refresh_token grants
  token_manager.go:RevokeTokenWithProvider — RFC 7009 revocation

Existing clientSecret deployments are unaffected — empty
clientAuthMethod maps to the historical client_secret_post behavior, and
clientAssertion remains nil unless the new fields are set.

Yaegi compatibility
===================
All required crypto/rsa, crypto/ecdsa, crypto/x509, encoding/pem and
crypto/sha256/384/512 symbols are exposed by the traefik/yaegi stdlib
symbol tables (RSA SignPKCS1v15 + SignPSS, ECDSA Sign,
ParsePKCS8/1PrivateKey, ParseECPrivateKey).

Tests (16 new)
==============
Algorithm-family coverage:
  TestIssue135_SignerRSAFamily — RS256/384/512 + PS256/384/512
  TestIssue135_SignerECDSAFamily — ES256/384/512, raw r||s shape
  TestIssue135_SignerRejectsAlgKeyMismatch
  TestIssue135_SignerJTIUniqueness — 50 sigs, all jti distinct
  TestIssue135_SignerPEMVariants — PKCS#8, PKCS#1, SEC1

Config validation:
  TestIssue135_ConfigValidation — full Validate() matrix
  TestIssue135_ConfigKeyPathLoadsFile

Wire-up:
  TestIssue135_AuthCodeExchangeUsesAssertion
  TestIssue135_RefreshTokenUsesAssertion
  TestIssue135_BackcompatClientSecretPath
  TestIssue135_RevocationUsesAssertion
  TestIssue135_BuildSignerFromInlineConfig
  TestIssue135_BuildSignerDefaultsToRS256
  TestIssue135_ClientSecretBasicAuth — Authorization header, no body creds
  TestIssue135_ClientSecretBasicURLEncodesReservedChars — :, +, /, @, =, &
  TestIssue135_ClientSecretBasicRevocation — revocation parity

Documentation
=============
  README.md — required-row note + 5 optional rows + dedicated section
  docs/CONFIGURATION.md — new Client Authentication section with three
    method subsections, OpenSSL keygen snippet, RFC links
  docs/index.html — 5 new config-table rows + Private Key JWT
    explainer card
  .traefik.yml + examples/complete-traefik-config.yaml — commented
    opt-in example

Out of scope (deferred)
=======================
mTLS / tls_client_auth (RFC 8705) — separate change; requires per-call
http.Client with tls.Config.Certificates and conflicts with the current
pooled HTTP client architecture.
2026-05-09 18:02:41 +01:00
lukaszraczylo bfd702a447 fix(jwk): keep parsed JWKS in local cache only (#134) (#136)
Under yaegi (Traefik's plugin runtime) json.Marshal exposes unexported
struct fields with an X-prefixed name. parsedJWKS{ keys map[string]
crypto.PublicKey } therefore round-tripped through Redis as
{"Xkeys":{"<kid>":{"N":<huge>,"E":65537}}} — *rsa.PublicKey.N is a
*big.Int that marshals to a JSON number hundreds of digits long. On
read, json.Unmarshal into interface{} parses numbers as float64, which
cannot represent that range:

  Failed to deserialize value for key .../discovery/v2.0/keys:parsed:
  json: cannot unmarshal number 2251513...
    into Go value of type float64

Auth still worked (the JWKCache rebuilt the keys in memory on every
miss) but the error log spammed every request.

Two structural problems were behind it:

* parsedJWKS holds crypto.PublicKey interface values that aren't
  meaningfully JSON-serializable. Even on compiled Go (where the
  unexported field marshals to {}), the post-roundtrip type assertion
  v.(*parsedJWKS) silently failed and the cache was useless.
* The same pattern applied to *JWKSet — the struct shape survived JSON
  but the type assertion still failed, defeating the cache for every
  call that went through Redis.

Both keys now use the new UniversalCache.SetLocal/GetLocal pair, which
skips the configured distributed backend entirely. JWK rotation is rare
and a per-replica HTTP fetch on cold cache is cheap, so cross-replica
coherence buys nothing for these entries.

Stale Redis entries written by previous versions are simply ignored —
the new code never reads under those keys, and Redis TTL retires them.

Includes regression coverage for the Azure round-trip, the
poisoned-stale-data scenario, and the SetLocal/GetLocal isolation
contract.

patch-release
2026-05-08 13:35:23 +01:00
lukaszraczylo 68c150eba4 fix(cache/redis): honor enableTLS for Redis backend (#133)
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the
config layer but silently dropped before reaching the connection pool, so
the plugin always dialed Redis in plaintext. This blocked TLS-only Redis
deployments such as AWS ElastiCache with in-transit encryption.

- Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and
  PoolConfig and forward them through universal_cache_singleton ->
  backends.Config -> PoolConfig.
- In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2
  minimum) with SNI defaulting to the host part of the configured
  Address when TLSServerName is empty, so ElastiCache cluster endpoints
  validate out of the box. Plain dial path now also propagates ctx.
- Add regression tests covering successful TLS negotiation with skip-
  verify, rejection of self-signed certs without skip-verify, rejection
  of plain TCP servers when EnableTLS=true, and unaffected plaintext
  behavior.
- Document maxRefreshTokenAgeSeconds (added in 1b6c861) and the implicit
  SSE / WebSocket auth bypass (added in 684a990) in README.md,
  docs/CONFIGURATION.md and docs/index.html.
- Add the missing redis.tlsSkipVerify row to docs/index.html and clarify
  the redis.enableTLS description.

patch-release
2026-05-07 12:24:13 +01:00
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00
lukaszraczylo 684a990f59 fix: reduce yaegi CPU footprint + require auth on SSE/WebSocket bypass
minor-release

Behaviour changes (potentially breaking for operators relying on the prior
unauthenticated SSE bypass):

* SSE (`Accept: text/event-stream`) and WebSocket upgrade requests now
  return 401 when no authenticated session is present. Previously the
  bypass forwarded unconditionally, which let any caller reach the
  backend by setting the right header. Excluded URLs are unchanged.
  Operators relying on unauthenticated SSE/WS access must move the path
  into ExcludedURLs.

Performance fixes (target: long-running dashboards like Grafana / ArgoCD
where many panels poll concurrently while the page stays open):

* Stop honouring isTestMode() for the singleton-token-cleanup interval
  under yaegi (the Traefik plugin runtime). In production the plugin was
  running a 20 Hz no-op cleanup ticker because runtime.Compiler ==
  "yaegi" tripped the test-mode branch.
* processAuthorizedRequest now resolves ID-token claims at most once per
  request via SessionData.GetIDTokenClaims (already cached on the
  session) and reuses them for both groups/roles extraction and
  header-template rendering. Previously every authenticated request
  parsed the JWT twice.
* Added extractGroupsAndRolesFromClaims to drive groups/roles off
  pre-parsed claims; extractGroupsAndRoles still works for tests.
* Removed the unconditional session.MarkDirty() in the header-templates
  branch. Templates only mutate request headers, not session state, so
  the prior MarkDirty was re-encrypting and rewriting all session
  cookies on every authenticated request that used header templates.

Other:

* Added isWebSocketUpgrade (RFC 6455 handshake detection — Connection:
  Upgrade + Upgrade: websocket, tolerant of multi-token Connection
  headers and case).
* Renamed applySSEUserHeaders -> applyBypassUserHeaders; it now returns
  bool so the dispatcher can reject unauthenticated SSE/WS with 401.
* Added tests for SSE and WS bypass covering both the auth-rejection
  path and the authenticated forward path.
2026-05-02 03:12:20 +01:00
lukaszraczylo 1b6c8616fd fix(refresh): coalesce refresh-token grants + bound goroutines + cache hot path (target v0.8.27) (#131)
* fix(refresh): wire RefreshCoordinator into the live refresh path

The RefreshCoordinator existed but was never instantiated. The actual
refresh path used only session.refreshMutex, which is per-SessionData
instance - and SessionData is pulled from a sync.Pool per request -
so concurrent requests sharing a refresh token had ZERO coordination.

Symptom: when access_token expired (e.g. 5min Zitadel default), every
in-flight request from a polling client (Grafana panels) entered the
refresh path simultaneously and POSTed the same refresh_token to the
IdP. With refresh-token rotation enabled (Zitadel/Authentik default),
only one grant succeeded; the rest got invalid_grant and each cleared
the entire session. Subsequent requests then thrashed in re-auth loops.

This commit:
- adds refreshCoordinator field on TraefikOidc
- instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig
- shuts it down in Close() under shutdownOnce
- routes refreshToken() through the coordinator via coordinatedTokenRefresh,
  which collapses concurrent grants to a single upstream call per
  refresh_token hash
- exports refreshCoordinatorSessionID for both internal hashing and the
  middleware-level wireup so dedup keys stay aligned

Behavioural notes:
- nil-coordinator fallback preserves existing tests that build TraefikOidc
  literals without going through the constructor
- followers receive the same TokenResponse/error as the leader, so no
  per-instance code paths change
- existing TestGetNewTokenWithRefreshToken_Concurrency still passes
  because it hits GetNewTokenWithRefreshToken directly, below the
  coordinator boundary

Tests:
- refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce
  to <=2 upstream calls; distinct tokens still run in parallel; nil
  coordinator falls back cleanly

* perf(cache): bound L1 backfill goroutines in HybridBackend

Get() and GetMany() previously spawned a goroutine per L2 hit to write
the value through to L1. Under sustained polling traffic (e.g. a Grafana
dashboard refreshing every 30s with N panels) this minted thousands of
goroutines, each running in Yaegi - directly contributing to the
~1000% CPU spike that pairs with the refresh-token herd.

Replace the per-hit goroutines with a single l1BackfillWorker fed by
l1BackfillBuffer, mirroring the existing asyncWriteBuffer/asyncWriteWorker
pattern for L2 writes. Buffer overflow drops the backfill (counted via
l1BackfillDrops) - a dropped backfill just means the next L2 hit for
that key re-queues it, which is safe.

Tests:
- TestHybridBackend_L1BackfillBounded: 1000 distinct L2 hits keep
  goroutine count within +20 of baseline (pre-fix it grew by ~1000)
- TestHybridBackend_L1BackfillFullDrops: drops are accounted for when
  the buffer is saturated and the worker is stopped

* feat(refresh): implement isRefreshTokenExpired heuristic

Replace the placeholder `return false` with a real check based on the
issued_at timestamp that SetRefreshToken already stamps into the session.
Gated by a new MaxRefreshTokenAgeSeconds config field (default 21600 =
6h, matching the existing comment). 0 disables the check.

This wires the previously-dead refreshTokenExpired branch in middleware.go,
which short-circuits AJAX requests with a 401 instead of letting them
hammer the IdP for a refresh token that's almost certainly stale - the
classic Grafana-after-long-pause failure mode.

Behaviour:
- maxRefreshTokenAge=0 disables the check (preserves prior behaviour)
- legacy sessions without issued_at still attempt one refresh; the IdP
  remains the source of truth on first try
- nil-receiver and nil-session guards keep test code that builds
  TraefikOidc literals safe

Tests:
- TestIsRefreshTokenExpired_DisabledWhenAgeZero
- TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp
- TestIsRefreshTokenExpired_WithinWindow
- TestIsRefreshTokenExpired_BeyondWindow
- TestIsRefreshTokenExpired_NilGuards

* perf(token): skip parseJWT on cache hit in VerifyToken

The token cache fast-return existed but ran AFTER parseJWT, so every
validation paid for base64 + JSON unmarshal even on a hit. Under bursty
traffic (e.g. 10+ concurrent panel requests on every Grafana dashboard
refresh, each calling validateStandardTokens which verifies BOTH the
access token and the ID token), this is two redundant parses per
request multiplied by the panel count.

Move the cache lookup ahead of parseJWT. On a hit the function returns
nil immediately. On a miss the original flow runs unchanged.

Also nil-guard t.tokenCache to keep partial-literal test instances safe
(matches the same pattern we already use for tokenBlacklist).

Tests:
- TestVerifyToken_CacheHitSkipsParse: cache pre-populated with claims
  for a token whose body would fail parseJWT - returns nil iff the
  fast-path bypasses the parse
- TestVerifyToken_CacheMissStillParses: a syntactically valid but
  unsigned token still errors past parseJWT on cache miss

* feat(refresh): cross-replica refresh-grant dedup via shared cache

The in-process RefreshCoordinator added in 9f96d8c already collapses
concurrent refresh-token grants on a single Traefik replica. With the
plugin's existing Redis (Dragonfly) cache infrastructure available, we
can extend that dedup across replicas: if pod A refreshes a token at
T+0 and pod B receives a request for the same session at T+1, pod B
should reuse pod A's result rather than POSTing the now-rotated refresh
token to the IdP.

Implementation:
- Add a refreshResultCache to UniversalCacheManager (memory-only when
  Redis is disabled, Redis-backed in production via the existing
  hybrid/Redis-only mode selection)
- Expose it through CacheManager.GetSharedRefreshResultCache and on the
  TraefikOidc struct as refreshResultCache (CacheInterface)
- Inside the closure passed to RefreshCoordinator.CoordinateRefresh,
  consult the cache first; on hit return immediately, on miss exchange
  with the IdP and populate the cache for peers
- 5s TTL: long enough for siblings to observe, short enough that a
  rotated refresh token cannot be re-supplied after the IdP has moved on
- Errors are intentionally NOT cached - peers must always be able to
  retry on their own

Pragmatic choice: optimistic cache rather than a hard distributed lock.
- A hard lock (SET NX + poll) doubles Redis RTT and risks dead-locks
  if a Traefik pod dies mid-grant.
- The user's BGP+Local externalTrafficPolicy already pins ingress for
  a session to one node in steady state, so cross-pod racing is rare.
- This optimistic path catches the rare failover case without adding
  failure modes.

Tests:
- TestCoordinatedTokenRefresh_CrossReplicaCacheHit: pre-populated cache
  short-circuits the upstream call entirely (0 IdP calls)
- TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache: leader stores
  a successful result for peers to find
- TestCoordinatedTokenRefresh_ErrorIsNotCached: invalid_grant must not
  poison the dedup cache - peers must retry independently
2026-04-30 18:52:39 +01:00
lukaszraczylo 4d28fa01ab perf(jwk,cache): cache parsed public keys + RLock token cache reads
Hot-path JWT verification rebuilt the public key on every call:
  jwk -> ToRSAPublicKey -> x509.MarshalPKIXPublicKey -> pem.Encode
  -> verifySignature -> pem.Decode -> x509.ParsePKIXPublicKey -> verify
Under yaegi this pinned a CPU when many concurrent dashboard panels
poll behind the middleware. The PEM round trip is pure waste.

* jwk.go: cache pre-parsed crypto.PublicKey per kid alongside the
  raw JWKSet (parallel cache entry, same 1h TTL, invalidates together).
* jwt.go: split verifySignatureWithKey from verifySignature; existing
  PEM-input entry point preserved for backchannel-logout callers.
* token_manager.go: VerifyJWTSignatureAndClaims now goes straight from
  jwks cache to verifySignatureWithKey, no PEM round trip and no
  per-request availableKids slice.
* universal_cache.go: token/JWK/session Get() takes RLock when the
  entry is unexpired, so concurrent token verifications no longer
  serialize on a single mutex. LRU semantics for general and metadata
  caches are unchanged (tests cover the strict-LRU contract there).
* mocks: MockJWKCache, EnhancedMockJWKCache, mockJWKCacheForLogout,
  staticJWKCache satisfy the extended interface.
2026-04-30 10:14:10 +01:00
lukaszraczylo 2d1b04c637 review fixes apr 2026 (#130)
* Multiple fixes

- refresh coordinator dedup + memory pressure wire
- middleware sse consolidation + timer leak + claim cache
- universal cache sync backfill + isDebug gate
- lazy background task race
- memory monitor stw cached + refresh() api

* fix(auth): suppress OIDC redirects on non-navigation requests

- [x] Add isNonNavigationRequest using Sec-Fetch-Mode and Accept headers
- [x] Add comprehensive TestIsNonNavigationRequest
- [x] Update ServeHTTP to 401 non-navigation and AJAX requests

Fixes #129

* feat(config): add custom CA and insecure skip verify for OIDC TLS

- [x] Add CACertPath, CACertPEM, InsecureSkipVerify to Config
- [x] Implement loadCACertPool for CA bundle loading
- [x] Update HTTPClientConfig with RootCAs and InsecureSkipVerify
- [x] Apply CA pool and skip verify to pooled HTTP clients
- [x] Enhance configKey to distinguish TLS configs
- [x] Add comprehensive ca_cert_test.go

Fixes #125

* feat(oidc): add custom CA certificate support for private OIDC providers

- [x] Add caCertPath, caCertPEM, insecureSkipVerify config options
- [x] Update traefik.yml with new OIDC client config fields
- [x] Add configuration schema descriptions for new options
- [x] Update README table and add Custom CA Certificates section

* Fix the documentation.

* test(redis): add oversized argument rejection test

- [x] Add TestRedisConn_RejectOversizedArgumentBytes
- [x] Import strings package

* Dependencies cleanup
2026-04-19 10:12:00 +01:00
lukaszraczylo ccbb98b9dd fix-issue-122 (#128) 2026-03-04 00:23:30 +00:00
Serhii Vasyliev 1362cc0dac Improve debug logging around callback URL matching (#126)
* Add debug logging around callback URL matching in ServeHTTP

The callback URL comparison at the core of OIDC flow had zero logging,
making it extremely difficult to diagnose redirect loop issues caused
by misconfigured callbackURL (e.g., full URL vs path-only).

Every other path comparison in ServeHTTP already logs debug info
(logout, backchannel, frontchannel, excluded URLs), but the callback
URL check was completely silent.

Added debug logs that show:
- The values being compared (request path vs configured callback)
- Whether the match succeeded or failed
- Configured redirURLPath during initialization

This would have immediately revealed the root cause of issue #1
where callbackURL was set as a full URL but compared against
req.URL.Path which only contains the path component.

Closes #3

* improve-callback-url-logging: Add init-time logging for callbackURL config
2026-02-23 10:36:37 +00:00
Yuval Bar-On 249dcad1b3 fix: prevent deadlock in SessionData.Clear method (#114)
Move mutex unlock before calling Save() to prevent potential deadlock
when Save() method needs to acquire the same mutex.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-16 15:02:33 +00:00
lukaszraczylo de4b4d7258 fix(cache): remove sync.Pool for Yaegi compatibility (#121)
- [x] Remove sync.Pool implementation that causes reflection panics
- [x] Replace pool-based NewRESPWriter with direct instantiation
- [x] Replace pool-based NewRESPReader with direct instantiation
- [x] Convert Release() methods to no-ops for API compatibility
- [x] Add documentation explaining sync.Pool removal for Yaegi
- [x] Remove "sync" import

Resolves #120
2026-01-19 17:52:31 +00:00
lukaszraczylo 9d52f1b018 feat(core): refactor linters config and improve code quality (#119)
- [x] Reorganize golangci-lint configuration with documented disable reasons
- [x] Simplify errcheck and revive linter rules with targeted exclusions
- [x] Pre-compile regex patterns in input_validation.go for performance
- [x] Fix type assertions in memory_shard.go and resp.go with safety checks
- [x] Replace string comparison with EqualFold for case-insensitive matching
- [x] Fix loop variable captures in jwk.go and logout.go
- [x] Change high goroutine log level from Info to Debug in autocleanup.go
- [x] Replace deprecated "cancelled" spelling with "canceled" throughout
- [x] Add nolint annotations for intentional unused parameters
- [x] Improve comment formatting for deprecated functions
- [x] Fix comment spelling: "marshalling" → "marshaling"
- [x] Refactor provider warnings formatting in internal/providers/warnings.go
- [x] Simplify metrics summary building in internal/recovery/metrics.go
- [x] Pre-allocate slice in error_recovery.go GetDegradedServices
- [x] Refactor context cancellation checks in redis.go
2026-01-15 10:40:49 +00:00
lukaszraczylo 57724918fe fix 116 (#118)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases

* docs: add security fix documentation for integer overflow protection

* test: fix goroutine tests to use mock OIDC servers

The TestContextAwareGoroutineManagement tests were making real HTTP
calls to hardcoded URLs like https://example.com, causing failures
in CI when those requests timeout or return HTTP errors.

Changes:
- Added createMockOIDCServer() helper function using httptest
- Updated GoroutineCleanupOnContextCancel to use mock server
- Updated NoGoroutineLeakOnMultipleInstances to use 3 mock servers
- Updated SingletonTasksAcrossInstances to use mock servers array

This prevents network calls and makes tests more reliable and faster.

Fixes test failures in GitHub Actions CI.
2026-01-08 22:50:46 +00:00
lukaszraczylo 775de2ada1 Fix cache serialisation (#117)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases
2026-01-08 22:06:19 +00:00
lukaszraczylo 7816e05c98 fix issue with logout url (#112)
* fix(logout): handle logout requests before OIDC initialization

- [x] Add debug logging to logout handler entry point
- [x] Move logout path check before OIDC initialization to enable logout when provider unavailable
- [x] Move excluded URL and SSE checks before initialization wait
- [x] Add debug logging for initialization wait to diagnose hanging requests
- [x] Add test for logout functionality without OIDC provider availability

* feat(logout): implement OIDC backchannel and front-channel logout

- [x] Add logout token validation and backchannel logout handler
- [x] Add front-channel logout handler with iframe support
- [x] Implement session invalidation cache for distributed deployments
- [x] Add comprehensive logout token claim verification (issuer, audience, events, iat, sid/sub)
- [x] Integrate session invalidation checks into authorization flow
- [x] Add configuration options for enabling backchannel/front-channel logout
- [x] Add extensive test coverage for logout flows and edge cases
- [x] Update documentation with logout configuration examples
- [x] Add middleware routing for logout endpoints
- [x] Extend cache manager with session invalidation cache support

Resolves #110

* fixup! feat(logout): implement OIDC backchannel and front-channel logout

* fixup! Merge branch 'main' into fix-issue-with-logout-url
2026-01-04 01:59:50 +00:00
Dominik Chilla 8bf7998150 Fix for Hashicorp Vault - accept opaque access tokens with dot-characters (#113) 2026-01-02 16:42:22 +00:00
muffn_ 22c4323fcb fix: set X-Forwarded-User header for SSE requests from existing session (#111)
Co-authored-by: muffin <MonsterMuffin@users.noreply.github.com>
2026-01-02 02:50:11 +00:00
lukaszraczylo 06b219d1f8 feat(dcr): Add Redis storage support for multi-replica deployments (#109)
- [x] Add file and Redis storage backends for DCR credentials
- [x] Implement storage abstraction with FileStore and RedisStore
- [x] Add factory function for automatic backend selection (auto/file/redis)
- [x] Integrate DCR credentials cache into UniversalCacheManager
- [x] Add comprehensive tests for storage backends and factory
- [x] Update configuration schema with storage backend options
- [x] Update documentation with multi-replica deployment guidance
- [x] Add Redis key prefix configuration for credential isolation
2025-12-31 12:52:39 +00:00
lukaszraczylo 413e4a1b7d LRU + cache conflicts prevention. (#104)
* LRU + cache conflicts prevention.

* Bugfix universalCache flooding ( issue #105 )

  1. Traefik cancels the context for old plugin instances
  2. Each plugin's Close() method is called
  3. The CacheInterfaceWrapper.Close() was calling cache.Close() on the shared singleton caches
  4. Each Close() triggered Clear() which logged "Cleared all items" at INFO level
2025-12-24 18:54:39 +00:00
lukaszraczylo 69e0d98c67 fixup! Add signing of the plugin on release. 2025-12-24 12:33:33 +00:00
lukaszraczylo 6d893df12b Add signing of the plugin on release. 2025-12-15 00:38:35 +00:00
lukaszraczylo 6efb78b7a8 Smarter approach to the cookies (#103)
* Smarter approach to the cookies

  - Single maxCookieSize = 1400 constant with clear documentation
  - Combined cookie storage for ~40-45% size reduction
  - Backward compatible migration from legacy cookies

* Tuneup the code.
2025-12-12 18:35:06 +00:00
lukaszraczylo d0b920c4f0 multiple realms fix (#102)
* Allow to use multiple realms

This change is a ressurection of PR #88 which can't be merged due to significant refactor of the codebase.

* Fix the autocleanup routine to handle multiple realms correctly, update tests.

* Metadata rediscovery when provider is unavailable for any reason during the start.

This one prevents the permanent 503 from the plugin when OIDC provider was for some reason unavailable during the start.
2025-12-10 13:07:22 +00:00
lukaszraczylo c474bbafd6 Cleanup [dec2025] (#101)
* Cleanup excessive comments.

* Remove leftovers hanging around from previous refactor

* Improve test coverage
2025-12-09 01:38:02 +00:00
lukaszraczylo 9126c74723 December 2025 Improvements - Azure AD, Internal Networks, Startup Race Condition (#100)
* Allow internal IPs for OIDC configuration via extra flag.

Addresses issue #97

* Allow for internal IPs in OIDC configuration.

Addresses issue #97.

* feat: Add allowPrivateIPAddresses config option for internal networks

Adds a new configuration option `allowPrivateIPAddresses` that allows
OIDC provider URLs to use private IP addresses (10.x.x.x, 172.16-31.x.x,
192.168.x.x). This is useful for internal deployments where Keycloak or
other OIDC providers run on private networks without DNS resolution.

Security considerations:
- Loopback addresses (127.0.0.1, localhost, ::1) remain blocked
- Link-local addresses (169.254.x.x) remain blocked
- Default is false (secure by default)

Fixes #97

* feat: Support non-email user identifiers for Azure AD

Add userIdentifierClaim configuration option to support Azure AD users
without email addresses. This allows using alternative JWT claims like
"sub", "oid", "upn", or "preferred_username" for user identification.

- Default behavior uses "email" claim (backward compatible)
- Falls back to "sub" claim if configured claim is missing
- allowedUsers matches against the configured claim value
- allowedUserDomains only applies when using email-based identification

Fixes #95

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Headers too big and 431 responses

Added new option `minimalHeaders` to reduce the size of forwarded headers from the auth middleware to backend services.

  - When minimalHeaders: false (default): All headers are forwarded as before
    - X-Forwarded-User (always set)
    - X-Auth-Request-Redirect
    - X-Auth-Request-User
    - X-Auth-Request-Token (the large ID token)
    - X-User-Groups, X-User-Roles (if configured)
  - When minimalHeaders: true: Reduces header overhead
    - X-Forwarded-User (always set)
    - X-User-Groups, X-User-Roles (still forwarded if configured)
    - Custom templated headers (still processed)
    - Skipped: X-Auth-Request-Token, X-Auth-Request-User, X-Auth-Request-Redirect

Fixes issues #64 and #86
2025-12-08 14:21:17 +00:00
lukaszraczylo a750c4f5b9 Size computation for allocation may overflow (#99)
* Size computation for allocation may overflow

Performing calculations involving the size of potentially large strings or slices can result in an overflow (for signed integer types) or a wraparound (for unsigned types). An overflow causes the result of the calculation to become negative, while a wraparound results in a small (positive) number.
2025-12-08 11:22:28 +00:00
lukaszraczylo 56051779ee Hotfix: goreleaser archive format. 2025-12-08 02:39:40 +00:00
lukaszraczylo 3f126d50f3 Force the v in the release tags and name. 2025-12-08 02:34:10 +00:00
lukaszraczylo 91f0fc9ab8 Switch to go releaser 2025-12-08 02:32:46 +00:00
lukaszraczylo 66b9ed0861 Reauthentication + redis fix
When introspection explicitly returns that a token is inactive/revoked/expired, the plugin now properly triggers re-authentication or refresh instead of falling back to ID token validation. This fixes the functional issue where users
weren't being redirected to re-authenticate.
Redis change ensures that when the caller's context is cancelled (e.g., the 200ms timeout in UniversalCache.Get()), the operation aborts quickly instead of continuing with retries.
2025-12-01 13:47:28 +00:00
lukaszraczylo e64fc7f730 Add redis support for distributed caching (#83)
* Add redis support for distributed caching

* Move towards the self-provided Redis connection pool and RESP protocol implementation.
Official redis client library won't work with yaegi.

* fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* ... and another all nighter.

* fixup! ... and another all nighter.

* fixup! fixup! ... and another all nighter.

* fixup! fixup! fixup! ... and another all nighter.

* Resolve issue #85 by adding ability to set custom claims in JWT tokens

* Remove redundant validation in auth middleware ( issue #89 )

* Add ability to set cookie prefix for session cookies ( #87 )

* fixup! Add ability to set cookie prefix for session cookies ( #87 )

* Add ability to set cookie max age - issue #91

* Potential fix for code scanning alert no. 10: Size computation for allocation may overflow

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fixup! Merge main into 0.8.0-redis: resolve conflicts

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-11-30 02:18:46 +00:00
lukaszraczylo 5fcbd54955 Add sharded cache and prevention of CPU spikes / locks (#96)
* Add sharded cache and prevention of CPU spikes / locks

* Add dynamic client registration with oidc provider

* Fix race condition introduced during the sharded cache implementation.

* Add page for traefikoidc.
2025-11-30 01:41:12 +00:00
lukaszraczylo e70cd1907c Create CNAME 2025-11-30 01:28:07 +00:00
lukaszraczylo e45b06c86d Fix markdown issues. 2025-10-17 14:40:50 +01:00
lukaszraczylo ae59a5e88a 0.7.10 (#80)
* Add ability to disable replay protection. - This is useful for runs with multiple traefik replicas to avoid false positives and tokens re-creation.
* Enhance the CI/CD pipelines
* Increase test coverage.
* Update vendored dependencies.
* Update behaviour on forceHTTPS as per issue #82
2025-10-16 10:56:28 +01:00
lukaszraczylo 79e9b164f9 release 0.7.9 (#78)
* Speed improvements.

After introduction of introspection the plugin became significantly slower.
This commit introduces several optimizations to bring the speed back up.

* Add relevant documentation and tests.
2025-10-13 10:43:35 +01:00
lukaszraczylo 93888e56d1 fixup! Multiple issues addressed (#76) 2025-10-09 00:56:53 +01:00
lukaszraczylo eff9bd7bd2 Multiple issues addressed (#76)
- Issue #74
- Issue #14
2025-10-09 00:44:03 +01:00
lukaszraczylo bde1db1c3b traefik plugin 0.7.7 (#73)
* Automatic discovery of the scopes.

Issue #61 raised very valid concerns about users configuring scopes that are not supported by the provider.
This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes.

Before:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed"
Authentication:  FAILS

After:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Middleware checks discovery doc → offline_access not supported
Automatically filters to: ["openid", "profile", "email"]
Authentication:  SUCCEEDS

* Resolves issue #74 by enabling user to specify expected audience in the configuration.

* Fix flaky tests.
2025-10-08 11:44:00 +01:00
lukaszraczylo 79d34ea4c9 Fix recursion in token resilience logic (#72) 2025-10-07 10:34:15 +01:00
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00
599 changed files with 171409 additions and 30488 deletions
+38
View File
@@ -0,0 +1,38 @@
# Code Owners for traefik-oidc
# These owners will be automatically requested for review when someone opens a PR
# Default owner for everything in the repo
* @lukaszraczylo
# Core authentication and middleware
/middleware/ @lukaszraczylo
/auth/ @lukaszraczylo
/handlers/ @lukaszraczylo
# OIDC providers
/internal/providers/ @lukaszraczylo
# Session management and security
/session/ @lukaszraczylo
/internal/security/ @lukaszraczylo
/security/ @lukaszraczylo
# Token management
/internal/token/ @lukaszraczylo
# Configuration
/config/ @lukaszraczylo
/.traefik.yml @lukaszraczylo
# GitHub Actions and CI/CD
/.github/ @lukaszraczylo
/.github/workflows/ @lukaszraczylo
/.golangci.yml @lukaszraczylo
# Documentation
/docs/ @lukaszraczylo
README.md @lukaszraczylo
# Dependencies
go.mod @lukaszraczylo
go.sum @lukaszraczylo
+15
View File
@@ -0,0 +1,15 @@
# These are supported funding model platforms
github: lukaszraczylo
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: https://monzo.me/lukaszraczylo
+123
View File
@@ -0,0 +1,123 @@
## Description
<!-- Provide a brief description of the changes in this PR -->
## Type of Change
<!-- Mark the relevant option with an "x" -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update
- [ ] Performance improvement
- [ ] Code refactoring
- [ ] Security fix
- [ ] Provider-specific fix/enhancement
## Related Issues
<!-- Link to related issues using #issue_number -->
Fixes #
Related to #
## Changes Made
<!-- List the main changes made in this PR -->
-
-
-
## Provider Impact
<!-- If this affects specific OIDC providers, list them here -->
- [ ] Google
- [ ] Azure AD
- [ ] Auth0
- [ ] Okta
- [ ] Keycloak
- [ ] AWS Cognito
- [ ] GitLab
- [ ] GitHub
- [ ] Generic OIDC
- [ ] All providers
## Testing Performed
<!-- Describe the tests you ran to verify your changes -->
- [ ] Unit tests pass locally
- [ ] Integration tests pass locally
- [ ] Race detector shows no issues
- [ ] Memory leak tests pass
- [ ] Manual testing performed
### Test Configuration
<!-- Provide details about your test configuration if applicable -->
**Provider tested:**
**Go version:**
**Traefik version:**
## Security Considerations
<!-- Describe any security implications of these changes -->
- [ ] This PR does not introduce security vulnerabilities
- [ ] Security scanning has been performed
- [ ] Credentials/secrets are properly handled
- [ ] Input validation is implemented
## Performance Impact
<!-- Describe any performance implications -->
- [ ] No performance impact expected
- [ ] Performance improved (describe how)
- [ ] Performance may be affected (describe why and mitigation)
## Breaking Changes
<!-- If this is a breaking change, describe the impact and migration path -->
**Breaking changes:**
**Migration guide:**
## Checklist
<!-- Ensure all items are checked before requesting review -->
- [ ] My code follows the project's code style
- [ ] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged and published
## Additional Context
<!-- Add any other context, screenshots, or information about the PR here -->
## Screenshots (if applicable)
<!-- Add screenshots to help explain your changes -->
---
**For Reviewers:**
Please verify:
- [ ] Code quality and style
- [ ] Test coverage is adequate
- [ ] Security implications reviewed
- [ ] Documentation is updated
- [ ] No performance regressions
+52
View File
@@ -0,0 +1,52 @@
version: 2
updates:
# Maintain dependencies for GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
open-pull-requests-limit: 5
commit-message:
prefix: "chore(deps)"
include: "scope"
labels:
- "dependencies"
- "github-actions"
reviewers:
- "lukaszraczylo"
# Maintain Go module dependencies
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
open-pull-requests-limit: 10
commit-message:
prefix: "chore(deps)"
include: "scope"
labels:
- "dependencies"
- "go"
reviewers:
- "lukaszraczylo"
# Group patch updates together
groups:
patch-updates:
patterns:
- "*"
update-types:
- "patch"
minor-updates:
patterns:
- "*"
update-types:
- "minor"
# Ignore certain dependencies if needed
ignore:
# Example: ignore specific versions
# - dependency-name: "github.com/example/package"
# versions: ["1.x", "2.x"]
+9
View File
@@ -0,0 +1,9 @@
# Ensure consistent line endings
* text=auto eol=lf
# GitHub Actions files should use LF
*.yml text eol=lf
*.yaml text eol=lf
# Shell scripts should use LF
*.sh text eol=lf
+225
View File
@@ -0,0 +1,225 @@
# GitHub Actions Workflows
This directory contains CI/CD workflows for the Traefik OIDC middleware.
## Workflows
### PR Validation (`pr-validation.yml`)
A comprehensive validation workflow that runs **all checks in parallel** for maximum speed and thorough testing.
**Triggered on:**
- Pull requests to `main` branch
- Pushes to `main` branch
**Parallel Jobs (20+ concurrent checks):**
#### Code Quality
- **Quick Checks** - Format, go vet, go mod verify
- **golangci-lint** - Comprehensive linting
- **Staticcheck** - Static analysis
#### Security
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database check
- **CodeQL** - GitHub's code analysis
#### Testing
- **Race Detector** - Concurrent access bug detection
- **Coverage** - Test coverage with 75% threshold
- **Memory Leaks** - Goroutine and memory leak detection
- **Integration Tests** - Full integration test suite
- **Regression Tests** - Prevent previously fixed bugs
- **Security Edge Cases** - Security-specific scenarios
- **Session Tests** - Session management validation
- **Token Tests** - Token validation scenarios
- **CSRF Tests** - CSRF protection validation
#### Provider Testing (Matrix)
Tests run in parallel for each OIDC provider:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Compatibility
- **Benchmarks** - Performance regression detection
- **Build Matrix** - linux/darwin × amd64/arm64
- **Go Versions** - Go 1.23 and 1.24 compatibility
#### Final Validation
- **All Checks Passed** - Ensures all jobs succeeded
## Workflow Features
### 🚀 Parallel Execution
All independent checks run simultaneously for fastest feedback (~5-10 minutes for full suite).
### 📊 Coverage Reporting
- Automatic PR comments with coverage statistics
- Per-package coverage breakdown
- 75% coverage threshold enforcement
### 🔒 Security First
- Multiple security scanners (gosec, govulncheck, CodeQL)
- SARIF report uploads for GitHub Security tab
- Security edge case testing
### 🎯 Comprehensive Testing
- Race condition detection
- Memory leak detection
- Provider-specific testing
- Integration and regression tests
### 📈 Performance Tracking
- Benchmark results stored as artifacts
- Performance regression detection
### ✅ Quality Gates
All checks must pass before PR can be merged:
- Code formatting and style
- Security vulnerabilities
- Test coverage threshold
- Race conditions
- Memory leaks
- Build success on all platforms
## Local Development
### Run checks locally before pushing:
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Run tests with race detector
go test -race -timeout=15m -count=1 ./...
# Check coverage
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# Run specific test suites
go test -v -run='.*Leak.*' ./... # Memory leak tests
go test -v -run='.*Integration.*' ./... # Integration tests
go test -v -run='.*Regression.*' ./... # Regression tests
# Run benchmarks
go test -bench=. -benchmem ./...
# Security scan
gosec ./...
govulncheck ./...
```
### Required Tools
Install these tools for local development:
```bash
# golangci-lint
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck
go install golang.org/x/vuln/cmd/govulncheck@latest
```
## Troubleshooting
### Workflow Fails
1. **Check job status** - Click on failed job for details
2. **Review logs** - Expand failed steps to see error messages
3. **Run locally** - Reproduce issue with local commands above
4. **Check coverage** - Ensure test coverage meets 75% threshold
### Coverage Below Threshold
Add tests to increase coverage:
```bash
# See which lines aren't covered
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
```
### Race Condition Detected
Run with race detector locally:
```bash
go test -race -v ./...
```
### Provider Test Failure
Test specific provider:
```bash
go test -v -run='.*Azure.*' ./internal/providers/...
```
## Performance Optimization
The workflow is optimized for speed:
- **Parallel execution** - All independent jobs run simultaneously
- **Go caching** - Dependencies cached between runs
- **Strategic ordering** - Quick checks run first for fast feedback
- **Fail-fast disabled** - Continue running all tests even if some fail
## Workflow Monitoring
### GitHub Actions Dashboard
Monitor workflow runs at: `https://github.com/{owner}/{repo}/actions`
### Status Badges
Add to README.md:
```markdown
![PR Validation](https://github.com/{owner}/{repo}/actions/workflows/pr-validation.yml/badge.svg)
```
### Notifications
Configure in repository settings:
- Settings → Notifications
- Choose email or Slack notifications for workflow failures
## Maintenance
### Update Go Version
Edit in workflow file:
```yaml
go-version: '1.24' # Update this
```
### Adjust Coverage Threshold
Edit in workflow file:
```yaml
THRESHOLD=75 # Adjust this value
```
### Add New Provider
Add to provider matrix:
```yaml
matrix:
provider:
- new_provider # Add here
```
## Additional Resources
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
- [golangci-lint Configuration](../.golangci.yml)
- [Dependabot Configuration](../dependabot.yml)
- [PR Template](../PULL_REQUEST_TEMPLATE.md)
+23
View File
@@ -0,0 +1,23 @@
name: Pull Request
on:
pull_request:
branches:
- main
push:
branches:
- "**"
- "!main"
permissions:
contents: read
pull-requests: write
security-events: write
jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24.11"
coverage-threshold: 70
secrets: inherit
+23
View File
@@ -0,0 +1,23 @@
name: Release
on:
push:
branches:
- main
paths:
- "**.go"
- "go.mod"
- "go.sum"
workflow_dispatch:
permissions:
id-token: write
contents: write
packages: write
jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24.11"
secrets: inherit
+3 -1
View File
@@ -1,2 +1,4 @@
docker/
.claude/
.claude/*.out
*.test
.leann/
+209
View File
@@ -0,0 +1,209 @@
version: "2"
run:
go: "1.24"
modules-download-mode: readonly
tests: true
linters:
enable:
- bodyclose
- dupl
- goconst
- gocritic
- gocyclo
- goprintffuncname
- gosec
- misspell
- noctx
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
disable:
- exhaustive
- funlen
- gocognit
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
- goprintffuncname # Disabled: naming convention is project-specific
- lll
- mnd
- testpackage
- whitespace # Disabled: style preference about newlines
- wsl
settings:
dupl:
threshold: 200 # Allow intentional duplication in provider patterns and token management
errcheck:
check-type-assertions: true
check-blank: false # Allow explicit blank assignments (_ = ...) to ignore errors
exclude-functions:
- (io.Closer).Close
- (*database/sql.Rows).Close
- (*database/sql.Stmt).Close
- (io.Writer).Write
- (*net/http.ResponseWriter).Write
- fmt.Fprintf
- fmt.Fprint
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
ignore-tests: true
gocritic:
# Disable style-only checks that add noise
disabled-checks:
- ifElseChain # Style preference, switch not always clearer
- elseif # Style preference
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
excludes:
- G104
- G404
severity: medium
confidence: medium
govet:
disable:
- fieldalignment
- shadow
enable-all: true
misspell:
locale: US
ignore-rules:
- traefik
- oidc
- keycloak
nolintlint:
require-explanation: true
require-specific: true
allow-unused: false
prealloc:
simple: true
range-loops: true
for-loops: false
revive:
rules:
- name: blank-imports
- name: context-as-argument
- name: context-keys-type
- name: dot-imports
- name: error-return
- name: error-strings
- name: error-naming
# - name: exported # Disabled: too noisy, not all exported functions need comments
# - name: if-return # Disabled: style preference
- name: increment-decrement
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
# - name: var-declaration # Disabled: explicit zero values can be clearer
# - name: package-comments # Disabled: handled by other tools
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
# - name: indent-error-flow # Disabled: style preference
- name: errorf
# - name: empty-block # Disabled: sometimes empty blocks are intentional
- name: superfluous-else
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
- name: unreachable-code
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
unparam:
check-exported: false
staticcheck:
checks:
- all
- -QF1001 # De Morgan's law - style preference, may affect Yaegi
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1011 # Omit type from declaration - style preference
- -QF1012 # Use fmt.Fprintf - style preference
- -SA9003 # Empty branch - sometimes intentional for future work
- -ST1000 # Package comment format - not required for all packages
- -ST1003 # Package name format - allowed for test packages
- -ST1016 # Receiver name consistency - legacy code
- -ST1020 # Comment format for methods - style preference
- -ST1021 # Comment format for types - style preference
- -ST1023 # Omit type from declaration - style preference
exclusions:
generated: lax
rules:
- linters:
- bodyclose
- dupl
- errcheck
- goconst
- gocyclo
- gosec
- govet
- ineffassign
- noctx
- prealloc
- unparam
- revive
- gocritic
path: _test\.go
- linters:
- dupl
- gocyclo
- govet
- noctx
- prealloc
- unparam
- revive
- gocritic
path: test.*\.go
- linters:
- gocritic
- unused
- errcheck
- revive
path: mocks.*\.go
- linters:
- errcheck
- revive
- gocritic
- govet
- unparam
path: internal/testutil/
- linters:
- govet
- unparam
- noctx
- prealloc
path: integration/
- linters:
- gosec
text: 'G404:'
- linters:
- all
path: vendor/
- linters:
- goconst
path: (.+)_test\.go
- linters:
- dupl
path: internal/providers/(auth0|keycloak|okta|google|azure|github|gitlab|cognito|generic)\.go
- linters:
- dupl
path: session\.go
- linters:
- dupl
path: session_chunk_manager\.go
text: "(extractJWTExpiration|extractJWTIssuedAt)"
paths:
- third_party$
- builtin$
- examples$
issues:
max-issues-per-linter: 0
max-same-issues: 0
uniq-by-line: true
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$
+185
View File
@@ -0,0 +1,185 @@
version: 2
# Two release artefacts:
#
# 1. The Traefik plugin: source-only — Traefik loads it via the Yaegi
# interpreter from the source tarball published on GitHub releases.
# 2. oidcgate: a standalone forward-auth daemon built from cmd/oidcgate.
# Shipped as both per-OS/arch binary archives AND a multi-arch Docker
# image at ghcr.io/lukaszraczylo/oidcgate, tagged to match the release.
builds:
- id: oidcgate
main: ./cmd/oidcgate
binary: oidcgate
env:
- CGO_ENABLED=0
goos:
- linux
- darwin
goarch:
- amd64
- arm64
flags:
- -trimpath
- -buildvcs=false
ldflags:
- -s -w
- -X main.version={{.Version}}
- -X main.commit={{.ShortCommit}}
- -X main.date={{.Date}}
mod_timestamp: "{{ .CommitTimestamp }}"
archives:
# Source archive for the Traefik plugin path. meta:true → no binary
# builds attached; everything comes from `files:` below.
- id: source-plugin
meta: true
formats: [tar.gz]
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
files:
- "*.go"
- "**/*.go"
- go.mod
- go.sum
- .traefik.yml
- LICENSE*
- README*
# Exclude test files and vendor from release archive
- "!**/*_test.go"
- "!vendor/**"
- "!docker/**"
- "!integration/**"
- "!regression/**"
- "!examples/**"
- "!docs/**"
- "!cmd/**"
# Per-OS/arch binary archives for the oidcgate daemon.
- id: oidcgate
ids: [oidcgate]
formats: [tar.gz]
name_template: "oidcgate_v{{ .Version }}_{{ .Os }}_{{ .Arch }}"
files:
- LICENSE*
- README*
- src: docs/OIDCGATE.md
dst: docs/
- src: examples/oidcgate.yaml
dst: examples/
# Build a Docker image per (linux, arch) combo. Tag suffixes are
# combined into a single multi-arch manifest list below via
# docker_manifests, so end users pull a single tag.
dockers:
- id: oidcgate-amd64
ids: [oidcgate]
goos: linux
goarch: amd64
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-amd64"
use: buildx
dockerfile: cmd/oidcgate/Dockerfile
build_flag_templates:
- "--pull"
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.title=oidcgate"
- "--label=org.opencontainers.image.description=Standalone OIDC forward-auth daemon for nginx/Caddy/Traefik/HAProxy/Envoy"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .FullCommit }}"
- "--label=org.opencontainers.image.created={{ .Date }}"
- "--label=org.opencontainers.image.source=https://github.com/lukaszraczylo/traefikoidc"
- "--label=org.opencontainers.image.url=https://github.com/lukaszraczylo/traefikoidc"
- "--label=org.opencontainers.image.documentation=https://github.com/lukaszraczylo/traefikoidc/blob/main/docs/OIDCGATE.md"
- "--label=org.opencontainers.image.licenses=MIT"
- id: oidcgate-arm64
ids: [oidcgate]
goos: linux
goarch: arm64
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-arm64"
use: buildx
dockerfile: cmd/oidcgate/Dockerfile
build_flag_templates:
- "--pull"
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.title=oidcgate"
- "--label=org.opencontainers.image.description=Standalone OIDC forward-auth daemon for nginx/Caddy/Traefik/HAProxy/Envoy"
- "--label=org.opencontainers.image.version={{ .Version }}"
- "--label=org.opencontainers.image.revision={{ .FullCommit }}"
- "--label=org.opencontainers.image.created={{ .Date }}"
- "--label=org.opencontainers.image.source=https://github.com/lukaszraczylo/traefikoidc"
- "--label=org.opencontainers.image.url=https://github.com/lukaszraczylo/traefikoidc"
- "--label=org.opencontainers.image.documentation=https://github.com/lukaszraczylo/traefikoidc/blob/main/docs/OIDCGATE.md"
- "--label=org.opencontainers.image.licenses=MIT"
# Multi-arch manifests — these are what users actually pull.
# Tags match the release tag (vX.Y.Z) exactly, plus a few convenience tags.
docker_manifests:
- name_template: "ghcr.io/lukaszraczylo/oidcgate:v{{ .Version }}"
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-amd64"
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-arm64"
- name_template: "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}"
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-amd64"
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-arm64"
- name_template: "ghcr.io/lukaszraczylo/oidcgate:v{{ .Major }}.{{ .Minor }}"
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-amd64"
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-arm64"
skip_push: auto
- name_template: "ghcr.io/lukaszraczylo/oidcgate:v{{ .Major }}"
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-amd64"
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-arm64"
skip_push: auto
- name_template: "ghcr.io/lukaszraczylo/oidcgate:latest"
image_templates:
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-amd64"
- "ghcr.io/lukaszraczylo/oidcgate:{{ .Version }}-arm64"
skip_push: auto
checksum:
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
algorithm: sha256
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
- "^Merge"
- "^WIP"
- "^chore:"
release:
github:
owner: lukaszraczylo
name: traefikoidc
name_template: "v{{ .Version }}"
draft: false
prerelease: auto
signs:
- cmd: cosign
signature: "${artifact}.sigstore.json"
args:
- sign-blob
- "--bundle=${signature}"
- "${artifact}"
- "--yes"
artifacts: checksum
output: true
# Sign the Docker images and manifests with cosign keyless.
docker_signs:
- cmd: cosign
artifacts: all
args:
- sign
- "${artifact}@${digest}"
- "--yes"
output: true
+79 -454
View File
@@ -4,476 +4,101 @@ type: middleware
import: github.com/lukaszraczylo/traefikoidc
summary: |
Middleware adding OpenID Connect (OIDC) authentication to Traefik routes.
OpenID Connect authentication middleware for Traefik. Replaces forward-auth
+ oauth2-proxy with a single plugin that auto-detects all major OIDC
providers (Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab,
generic) and OAuth 2.0 for GitHub.
This middleware replaces the need for forward-auth and oauth2-proxy when using Traefik as a reverse proxy.
It provides a complete OIDC authentication solution with features like domain restrictions,
role-based access control, token caching, and more.
Features: ID-token validation with auto-discovery, session encryption,
proactive token refresh, RBAC via roles/groups claims, domain restriction,
templated downstream headers, security headers (CSP, HSTS, CORS), rate
limiting, PKCE, opaque-token introspection (RFC 7662), back/front-channel
logout, Dynamic Client Registration (RFC 7591), and Redis-backed shared
state for multi-replica deployments.
The middleware has been tested with Auth0, Logto, Google, and other standard OIDC providers.
It includes special handling for Google's OAuth implementation to ensure compatibility.
It supports various authentication scenarios including:
- Basic authentication with customizable callback and logout URLs
- Email domain restrictions to limit access to specific organizations
- Role and group-based access control
- Public URLs that bypass authentication
- Rate limiting to prevent brute force attacks
- Custom post-logout redirect behavior
- Secure session management with encrypted cookies
- Automatic token validation and refresh
Full documentation: https://github.com/lukaszraczylo/traefikoidc
testData:
# Required parameters
providerURL: https://accounts.google.com # Base URL of the OIDC provider
clientID: 1234567890.apps.googleusercontent.com # OAuth 2.0 client identifier
clientSecret: secret # OAuth 2.0 client secret
callbackURL: /oauth2/callback # Path where the OIDC provider will redirect after authentication
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long # Key used to encrypt session data (must be at least 32 bytes)
# Required
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
# Alternative: RFC 7523 private_key_jwt client authentication (Entra ID,
# Okta, Auth0, Keycloak). Replaces clientSecret with a signed JWT assertion.
# See README "Client authentication via private key JWT".
# clientAuthMethod: private_key_jwt
# clientAssertionKeyID: my-key-2026
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
# # File path option:
# clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
# # Or inline PEM (PKCS#8 / PKCS#1 / SEC1):
# clientAssertionPrivateKey: |
# -----BEGIN PRIVATE KEY-----
# MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDexampleexample
# -----END PRIVATE KEY-----
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
# Optional parameters with defaults
logoutURL: /oauth2/logout # Path for handling logout requests (if not provided, it will be set to callbackURL + "/logout")
postLogoutRedirectURI: /oidc/different-logout # URL to redirect to after logout (default: "/")
# Common production knobs
logoutURL: /oauth2/logout
postLogoutRedirectURI: /
forceHTTPS: true # default; only set false for plaintext HTTP local dev
logLevel: info
rateLimit: 100
scopes: # Additional scopes to append to defaults ["openid", "profile", "email"]
- roles # Result: ["openid", "profile", "email", "roles"]
allowedUserDomains: # Restricts access to specific email domains (if not provided, relies on OIDC provider)
# Access control
allowedUserDomains:
- company.com
- subsidiary.com
allowedUsers: # Restricts access to specific email addresses regardless of domain
- specific-user@company.com
- another-user@gmail.com
allowedRolesAndGroups: # Restricts access to users with specific roles or groups (if not provided, no role/group restrictions)
- guest-endpoints
allowedRolesAndGroups:
- admin
- developer
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
excludedURLs: # Lists paths that bypass authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /public
excludedURLs:
- /health
- /metrics
headers: # Custom headers to set with templated values from claims and tokens
# NOTE: If you encounter "can't evaluate field AccessToken in type bool" errors,
# you may need to escape the templates. See the headers section in configuration below.
- name: "X-User-Email"
value: "{{.Claims.email}}"
- name: "X-User-ID"
value: "{{.Claims.sub}}"
- name: "Authorization"
value: "Bearer {{.AccessToken}}"
- name: "X-User-Roles"
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
# Advanced parameters (usually discovered automatically from provider metadata)
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
cookieDomain: "" # Explicit domain for session cookies (e.g., ".example.com" for multi-subdomain setups)
overrideScopes: false # When true, replaces default scopes instead of appending (default: false)
refreshGracePeriodSeconds: 60 # Seconds before token expiry to attempt proactive refresh (default: 60)
# --- Provider Specific Configuration Examples ---
#
# Below are example configurations tailored for specific OIDC providers.
# Uncomment and adapt the relevant section for your provider.
# Remember to replace placeholder values (like client IDs, secrets, domains)
# with your actual credentials and settings.
#
# For all providers, ensure claims like email, roles, and groups are
# configured to be included in the ID TOKEN. This plugin validates ID tokens.
# --- Keycloak Example ---
# testDataKeycloak:
# providerURL: https://your-keycloak-domain/realms/your-realm # e.g., http://localhost:8080/realms/master
# clientID: your-keycloak-client-id
# clientSecret: your-keycloak-client-secret # Store securely, e.g., urn:k8s:secret:namespace:secret-name:key
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-keycloak"
# scopes: # Default ["openid", "profile", "email"] are usually sufficient. Add others if mappers depend on them.
# - roles # Example: if you mapped Keycloak roles to a 'roles' claim in the ID token
# - groups # Example: if you mapped Keycloak groups to a 'groups' claim in the ID token
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
# - admin
# - editor
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
# # See README.md "Provider Configuration Recommendations" for Keycloak.
# --- Azure AD (Microsoft Entra ID) Example ---
# testDataAzureAD:
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0 # Replace your-tenant-id
# clientID: your-azure-ad-client-id
# clientSecret: your-azure-ad-client-secret # Store securely
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
# scopes: # Defaults ["openid", "profile", "email"] are good.
# # Azure AD may require specific scopes for certain graph API permissions if you were to use the access token,
# # but for ID token claims, defaults are often enough.
# # Group claims need to be configured in Azure AD App Registration -> Token Configuration -> Add groups claim.
# allowedUserDomains:
# - yourcompany.com
# allowedRolesAndGroups: # If you configured group claims (typically 'groups') or app roles in Azure AD
# - "group-object-id-1" # Azure AD group claims can be Object IDs by default
# - "AppRoleName"
# # See README.md "Provider Configuration Recommendations" for Azure AD.
# --- Google Workspace / Google Cloud Identity Example ---
# testDataGoogle:
# providerURL: https://accounts.google.com # This is standard for Google
# clientID: your-google-client-id.apps.googleusercontent.com
# clientSecret: your-google-client-secret # Store securely
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-google"
# scopes: # Defaults ["openid", "profile", "email"] are handled. Plugin manages Google-specifics.
# # Do NOT add 'offline_access' - plugin handles this.
# allowedUserDomains: # Useful for Google Workspace users
# - your-gsuite-domain.com
# # Google includes 'hd' (hosted domain) claim which can be used with allowedUserDomains.
# # Other claims like 'email', 'sub', 'name' are standard.
# # See README.md "Provider Configuration Recommendations" for Google.
# --- Auth0 Example ---
# testDataAuth0:
# providerURL: https://your-auth0-domain.auth0.com # Replace with your Auth0 domain
# clientID: your-auth0-client-id
# clientSecret: your-auth0-client-secret # Store securely
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-auth0"
# scopes: # Defaults ["openid", "profile", "email"]. Add custom scopes if your Auth0 Rules/Actions require them.
# - read:custom_data # Example custom scope
# allowedRolesAndGroups: # Based on claims added via Auth0 Rules or Actions (e.g. namespaced claims)
# - "https://your-app.com/roles:admin"
# - editor
# # Use Auth0 Rules or Actions to add custom claims (roles, permissions) to the ID Token.
# # Ensure postLogoutRedirectURI is in Auth0 app's "Allowed Logout URLs".
# # See README.md "Provider Configuration Recommendations" for Auth0.
# --- Generic OIDC Provider Example ---
# testDataGenericOIDC:
# providerURL: https://your-generic-oidc-provider.com/oidc # Issuer URL for your provider
# clientID: your-generic-client-id
# clientSecret: your-generic-client-secret # Store securely
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-generic"
# scopes: # Must include "openid". "profile" and "email" are common.
# - openid
# - profile
# - email
# - custom_scope_for_claims # If your provider needs specific scopes for ID token claims
# allowedRolesAndGroups:
# - user_role_from_id_token
# # Consult your provider's documentation on how to map attributes/roles/groups to ID Token claims.
# # Verify ID Token contents (e.g. jwt.io) to see available claims.
# # See README.md "Provider Configuration Recommendations" for Generic OIDC.
# Configuration documentation
configuration:
providerURL:
type: string
description: |
The base URL of the OIDC provider. This is the issuer URL that will be used to discover
OIDC endpoints like authorization, token, and JWKS URIs.
Examples:
- https://accounts.google.com
- https://login.microsoftonline.com/tenant-id/v2.0
- https://your-auth0-domain.auth0.com
- https://your-logto-instance.com/oidc
required: true
clientID:
type: string
description: |
The OAuth 2.0 client identifier obtained from your OIDC provider.
This is the public identifier for your application.
required: true
clientSecret:
type: string
description: |
The OAuth 2.0 client secret obtained from your OIDC provider.
This should be kept confidential and not exposed in client-side code.
For Kubernetes deployments, you can use the secret reference format:
urn:k8s:secret:namespace:secret-name:key
required: true
callbackURL:
type: string
description: |
The path where the OIDC provider will redirect after authentication.
This must match one of the redirect URIs configured in your OIDC provider.
The full redirect URI will be constructed as:
[scheme]://[host][callbackURL]
Example: /oauth2/callback
required: true
sessionEncryptionKey:
type: string
description: |
Key used to encrypt session data stored in cookies.
Must be at least 32 bytes long for security.
Example: potato-secret-is-at-least-32-bytes-long
required: true
logoutURL:
type: string
description: |
The path for handling logout requests.
If not provided, it will be set to callbackURL + "/logout".
Example: /oauth2/logout
required: false
postLogoutRedirectURI:
type: string
description: |
The URL to redirect to after logout.
Default: "/"
Example: /logged-out-page
required: false
# Scopes are appended to the defaults ["openid", "profile", "email"]
scopes:
type: array
description: |
Additional OAuth 2.0 scopes to append to the default scopes.
Default scopes are always included: ["openid", "profile", "email"]
User-provided scopes are appended to defaults with automatic deduplication.
For example, specifying ["roles", "custom_scope"] results in:
["openid", "profile", "email", "roles", "custom_scope"]
Include "roles" or similar scope if you need role/group information.
Note: For Google OAuth, the middleware automatically handles the
proper authentication parameters and does NOT require the "offline_access"
scope (which Google rejects as invalid). See documentation for details.
required: false
items:
type: string
logLevel:
type: string
description: |
Sets the logging verbosity.
Valid values: "debug", "info", "error"
Default: "info"
required: false
enum:
- debug
- info
- error
forceHTTPS:
type: boolean
description: |
Forces the use of HTTPS for all URLs.
This is recommended for security in production environments.
Default: true
required: false
rateLimit:
type: integer
description: |
Sets the maximum number of requests per second.
This helps prevent brute force attacks.
Default: 100
Minimum: 10
required: false
excludedURLs:
type: array
description: |
Lists paths that bypass authentication.
These paths will be accessible without OIDC authentication.
The middleware uses prefix matching, so "/public" will match
"/public", "/public/page", "/public-data", etc.
Examples: ["/health", "/metrics", "/public"]
required: false
items:
type: string
allowedUserDomains:
type: array
description: |
Restricts access to users with email addresses from specific domains.
If not provided, the middleware relies entirely on the OIDC provider
for authentication decisions.
Examples: ["company.com", "subsidiary.com"]
required: false
items:
type: string
allowedUsers:
type: array
description: |
Restricts access to specific email addresses.
If provided, only users with these exact email addresses will be allowed access,
in addition to any domain-level restrictions set by allowedUserDomains.
This provides fine-grained control over individual access and can be used
together with allowedUserDomains for flexible access control strategies.
Examples: ["user1@example.com", "admin@company.com"]
required: false
items:
type: string
allowedRolesAndGroups:
type: array
description: |
Restricts access to users with specific roles or groups.
If not provided, no role/group restrictions are applied.
The middleware checks both the "roles" and "groups" claims in the ID token.
Examples: ["admin", "developer"]
required: false
items:
type: string
revocationURL:
type: string
description: |
The endpoint for revoking tokens.
If not provided, it will be discovered from provider metadata.
Example: https://accounts.google.com/revoke
required: false
oidcEndSessionURL:
type: string
description: |
The provider's end session endpoint.
If not provided, it will be discovered from provider metadata.
Example: https://accounts.google.com/logout
required: false
enablePKCE:
type: boolean
description: |
Enables PKCE (Proof Key for Code Exchange) for the OAuth 2.0 authorization code flow.
PKCE adds an extra layer of security to protect against authorization code interception attacks.
Not all OIDC providers support PKCE, so this should only be enabled if your provider supports it.
If enabled, the middleware will generate and use a code verifier/challenge pair during authentication.
Default: false
required: false
cookieDomain:
type: string
description: |
Explicit domain for session cookies. This is important for multi-subdomain setups
and reverse proxy deployments to ensure consistent cookie handling.
When set, all session cookies will use this domain. When not set, the domain
is auto-detected from the request headers (X-Forwarded-Host or Host).
Use a leading dot for subdomain-wide cookies (e.g., ".example.com" allows
cookies to be shared between app.example.com, api.example.com, etc.).
Use a specific domain for host-only cookies (e.g., "app.example.com" restricts
cookies to that exact domain).
This setting is crucial to prevent authentication issues like "CSRF token missing
in session" errors that can occur when cookies are created with inconsistent domains.
Examples:
- ".example.com" - Allows all subdomains to share cookies
- "app.example.com" - Restricts cookies to this specific host
Default: "" (auto-detected from request headers)
required: false
overrideScopes:
type: boolean
description: |
When set to true, the scopes you provide will completely replace the default scopes
(openid, profile, email) instead of being appended to them.
This is useful when you need precise control over the scopes sent to the OIDC provider,
such as when a provider requires specific scopes or when you want to minimize the
requested permissions.
Default: false (appends user scopes to defaults)
required: false
refreshGracePeriodSeconds:
type: integer
description: |
The number of seconds before a token expires to attempt proactive refresh.
When a request is made and the access token will expire within this grace period,
the middleware will attempt to refresh the token proactively. This helps prevent
authentication interruptions for active users.
Setting this to 0 disables proactive refresh (tokens are only refreshed after expiry).
Default: 60 (1 minute before expiry)
required: false
- roles
# Templated headers forwarded to backends.
# NOTE: use quadruple braces — the YAML parser collapses {{{{ → {{ so the
# Go template engine receives the correct expression.
headers:
type: array
description: |
Custom HTTP headers to set with templated values derived from OIDC claims and tokens.
Each header has a name and a value template that can access:
- {{.Claims.field}} - Access ID token claims (e.g., email, sub, name)
- {{.AccessToken}} - The raw access token string
- {{.IdToken}} - The raw ID token string
- {{.RefreshToken}} - The raw refresh token string
- name: X-User-Email
value: "{{{{.Claims.email}}}}"
- name: Authorization
value: "Bearer {{{{.AccessToken}}}}"
Templates support Go template syntax including conditionals and iteration.
Variable names are case-sensitive - use .Claims not .claims.
# Security headers (default profile is enabled out of the box)
securityHeaders:
enabled: true
profile: default
IMPORTANT: Template Escaping
If you encounter the error "can't evaluate field AccessToken in type bool" when
starting Traefik, this means Traefik is trying to evaluate the template expressions
before passing them to the plugin. To fix this, you need to escape the templates
using one of these methods:
# Optional: Redis for multi-replica deployments. See docs/REDIS.md.
# redis:
# enabled: true
# address: redis:6379
# password: urn:k8s:secret:redis:password
# cacheMode: hybrid
1. Use YAML literal style (recommended):
headers:
- name: "Authorization"
value: |
Bearer {{.AccessToken}}
# Optional: bearer-token authentication for M2M (machine-to-machine) API
# clients. Default off. When enabled, requests presenting
# "Authorization: Bearer <jwt>" are validated against the configured OIDC
# provider (signature/issuer/audience/exp) and forwarded without creating
# a cookie session. The bearer path REJECTS ID tokens, requires a non-
# default audience, and never trusts the `email` claim as the identifier.
# See docs/BEARER_AUTH.md for the full threat model.
#
# enableBearerAuth: true # opt-in
# audience: https://api.example.com # REQUIRED when bearer is enabled
# bearerIdentifierClaim: sub # default; used as X-Forwarded-User. `email` is rejected.
# stripAuthorizationHeader: true # default; drops the raw token before forwarding
# bearerEmitWWWAuthenticate: true # default; RFC 6750 hint on 401s
# bearerOverridesCookie: false # default; cookie wins when both are present
# requireTokenIntrospection: false # opt-in; calls RFC 7662 introspection per request
# maxTokenAgeSeconds: 86400 # 24h cap on iat (rejects clock-skew/forever tokens)
# maxIdentifierLength: 256 # cap on the sanitised principal identifier
# bearerFailureThreshold: 20 # consecutive 401s/IP that trip the throttle
# bearerFailureWindowSeconds: 60 # rolling window over which 401s are counted
# bearerFailurePenaltySeconds: 60 # 429 + Retry-After duration after threshold trips
2. Use single quotes:
headers:
- name: "Authorization"
value: 'Bearer {{.AccessToken}}'
3. For inline double quotes, escape the braces:
headers:
- name: "Authorization"
value: "Bearer {{"{{.AccessToken}}"}}"
Examples:
- name: "X-User-Email", value: "{{.Claims.email}}"
- name: "Authorization", value: "Bearer {{.AccessToken}}"
- name: "X-User-Roles", value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
required: false
items:
type: object
properties:
name:
type: string
description: The HTTP header name to set
value:
type: string
description: Template string for the header value
+359 -859
View File
File diff suppressed because it is too large Load Diff
+49
View File
@@ -0,0 +1,49 @@
# Security Fix: Integer Overflow Protection in Cache Serialization
## Summary
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
## Vulnerability
**Locations**: `universal_cache.go` lines 789 and 811
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
## Fix Applied
1. **Added size limit constant**:
```go
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
```
2. **Size validation before allocation**:
- Validates entry size doesn't exceed limit
- Validates adding marker byte won't overflow
- Returns descriptive error messages
3. **Comprehensive test coverage**:
- Oversized byte slices (>64 MiB)
- Exact max size edge case
- Safe sizes (normal operation)
- Large JSON data structures
## Verification
✅ All tests pass with race detection
✅ No security issues (golangci-lint, gosec)
✅ 76.3% test coverage maintained
## Impact
- No breaking changes
- Negligible performance overhead
- Prevents potential buffer overflows
- Predictable memory usage
---
**Date**: January 8, 2026
**Severity**: High → Resolved
-308
View File
@@ -1,308 +0,0 @@
# Test Execution Guide
This guide explains how to run tests efficiently with the new test categorization and optimization system.
## Quick Start
### Fast Development Testing (Default - Target: < 30 seconds)
```bash
# Run quick smoke tests only
go test ./...
# Or explicitly run in short mode
go test ./... -short
```
### Extended Testing (Target: 2-5 minutes)
```bash
# Enable extended tests with more iterations and concurrency
RUN_EXTENDED_TESTS=1 go test ./...
# Or use the flag equivalent (if using test runner that supports it)
go test ./... -extended
```
### Long-Running Performance Tests (Target: 5-15 minutes)
```bash
# Enable comprehensive performance and stress tests
RUN_LONG_TESTS=1 go test ./...
```
### Full Stress Testing (Target: 10-30 minutes)
```bash
# Enable all stress tests with maximum parameters
RUN_STRESS_TESTS=1 go test ./...
```
## Test Categories
### 1. Quick Tests (Default)
- **Purpose**: Fast feedback during development
- **Duration**: < 30 seconds total
- **Features**:
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Minimal concurrency
- Essential memory leak checks
**Configuration**:
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Cache Size: 50
- Timeout: 10 seconds
### 2. Extended Tests
- **Purpose**: Comprehensive testing before commits
- **Duration**: 2-5 minutes
- **Features**:
- Increased test coverage
- More iterations (5-10)
- Medium concurrency tests
- Enhanced memory leak detection
**Configuration**:
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Cache Size: 200
- Timeout: 30 seconds
### 3. Long Tests
- **Purpose**: Performance validation and stress testing
- **Duration**: 5-15 minutes
- **Features**:
- High iteration counts (50-100)
- High concurrency scenarios
- Large data sets
- Comprehensive memory testing
**Configuration**:
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Cache Size: 1000
- Timeout: 60 seconds
### 4. Stress Tests
- **Purpose**: Maximum load testing and edge case validation
- **Duration**: 10-30 minutes
- **Features**:
- Extreme iteration counts (100-500)
- Maximum concurrency (100+)
- Large memory allocations
- Edge case combinations
**Configuration**:
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Cache Size: 2000
- Timeout: 120 seconds
## Environment Variables
### Test Execution Control
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1 # Enable extended tests
export RUN_LONG_TESTS=1 # Enable long-running tests
export RUN_STRESS_TESTS=1 # Enable stress tests
# Disable specific features
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
```
### Parameter Customization
```bash
# Customize concurrency limits
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
# Customize iteration limits
export TEST_MAX_ITERATIONS=50 # Override max test iterations
# Customize memory thresholds
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
```
## Test-Specific Behavior
### Memory Leak Tests
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
- **Long Mode**: 50-100 iterations, large data sets, performance focus
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
### Concurrency Tests
- **Quick Mode**: 2-5 concurrent operations, basic race detection
- **Extended Mode**: 10-20 concurrent operations, moderate stress
- **Long Mode**: 20-50 concurrent operations, high contention
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
### Cache Tests
- **Quick Mode**: Small caches (50 items), basic operations
- **Extended Mode**: Medium caches (200 items), varied operations
- **Long Mode**: Large caches (1000 items), performance testing
- **Stress Mode**: Very large caches (2000+ items), stress testing
## Integration with CI/CD
### GitHub Actions Example
```yaml
# Quick tests for every push/PR
- name: Quick Tests
run: go test ./... -short
# Extended tests for main branch
- name: Extended Tests
if: github.ref == 'refs/heads/main'
run: RUN_EXTENDED_TESTS=1 go test ./...
# Nightly comprehensive testing
- name: Nightly Stress Tests
if: github.event_name == 'schedule'
run: RUN_STRESS_TESTS=1 go test ./...
```
### Local Development Workflow
```bash
# During active development
go test ./... -short
# Before committing
RUN_EXTENDED_TESTS=1 go test ./...
# Before major releases
RUN_LONG_TESTS=1 go test ./...
# Performance validation
RUN_STRESS_TESTS=1 go test ./...
```
## Performance Optimization Features
### Dynamic Test Scaling
The test system automatically adjusts parameters based on:
- Test mode (quick/extended/long/stress)
- Available resources
- Environment variables
- Previous test performance
### Memory Management
- **Garbage Collection**: Forced GC between test iterations
- **Memory Monitoring**: Real-time memory growth tracking
- **Leak Detection**: Goroutine and memory leak prevention
- **Resource Cleanup**: Automatic cleanup of test resources
### Timeout Management
- **Adaptive Timeouts**: Timeouts scale with test complexity
- **Graceful Degradation**: Tests adapt to slower environments
- **Early Termination**: Failed tests terminate quickly
## Troubleshooting
### Tests Taking Too Long
```bash
# Check if running in extended mode accidentally
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
# Force quick mode
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
go test ./... -short
```
### Memory Issues
```bash
# Reduce memory limits for constrained environments
export TEST_MEMORY_THRESHOLD_MB=5.0
export TEST_MAX_CONCURRENCY=2
go test ./...
```
### Concurrency Issues
```bash
# Reduce concurrency for slower systems
export TEST_MAX_CONCURRENCY=5
export TEST_MAX_ITERATIONS=10
go test ./...
```
### Skip Specific Test Types
```bash
# Skip memory leak detection if problematic
export DISABLE_LEAK_DETECTION=1
go test ./...
```
## Benchmarking
### Running Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
### Benchmark Categories
- **Basic Operations**: Set/Get performance
- **Concurrency**: Multi-threaded performance
- **Memory**: Allocation and cleanup performance
- **Cache**: Eviction and cleanup performance
## Best Practices
### For Developers
1. Always run quick tests during development (`go test ./... -short`)
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
3. Use appropriate test categories for your use case
4. Monitor test execution time and adjust if needed
### For CI/CD
1. Use quick tests for fast feedback on PRs
2. Use extended tests for main branch validation
3. Use long tests for release validation
4. Use stress tests for nightly/weekly validation
### For Performance Testing
1. Use consistent environment variables
2. Run tests multiple times for statistical significance
3. Monitor both execution time and resource usage
4. Use profiling tools for detailed analysis
## Examples
### Daily Development
```bash
# Fast tests while coding
go test ./... -short
# Before git commit
RUN_EXTENDED_TESTS=1 go test ./...
```
### Release Testing
```bash
# Comprehensive validation
RUN_LONG_TESTS=1 go test ./...
# Stress testing
RUN_STRESS_TESTS=1 go test ./...
```
### Custom Configuration
```bash
# Custom limits for specific environment
export TEST_MAX_CONCURRENCY=8
export TEST_MAX_ITERATIONS=25
export TEST_MEMORY_THRESHOLD_MB=15.0
RUN_EXTENDED_TESTS=1 go test ./...
```
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
+1518
View File
File diff suppressed because it is too large Load Diff
-360
View File
@@ -1,360 +0,0 @@
// Package auth provides authentication-related functionality for the OIDC middleware.
package auth
import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/google/uuid"
)
// AuthHandler provides core authentication functionality for OIDC flows
type AuthHandler struct {
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// NewAuthHandler creates a new AuthHandler instance
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
return &AuthHandler{
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
}
}
// InitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
session SessionData, redirectURL string,
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
const maxRedirects = 5
redirectCount := session.GetRedirectCount()
if redirectCount >= maxRedirects {
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
session.ResetRedirectCount()
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
return
}
session.IncrementRedirectCount()
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
h.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Generate PKCE code verifier and challenge if PKCE is enabled
var codeVerifier, codeChallenge string
if h.enablePKCE {
codeVerifier, err = generateCodeVerifier()
if err != nil {
h.logger.Errorf("Failed to generate code verifier: %v", err)
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
codeChallenge, err = deriveCodeChallenge()
if err != nil {
h.logger.Errorf("Failed to generate code challenge: %v", err)
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
return
}
h.logger.Debugf("PKCE enabled, generated code challenge")
}
session.SetAuthenticated(false)
session.SetEmail("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
session.SetNonce("")
session.SetCodeVerifier("")
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
if h.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(req.URL.RequestURI())
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
session.MarkDirty()
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
csrfToken, nonce)
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// BuildAuthURL constructs the OIDC provider authorization URL.
// It builds the URL with all necessary parameters including client_id, scopes,
// PKCE parameters, and provider-specific parameters for Google and Azure.
func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", h.clientID)
params.Set("response_type", "code")
params.Set("redirect_uri", redirectURL)
params.Set("state", state)
params.Set("nonce", nonce)
if h.enablePKCE && codeChallenge != "" {
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}
scopes := make([]string, len(h.scopes))
copy(scopes, h.scopes)
if h.isGoogleProv() {
params.Set("access_type", "offline")
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
params.Set("prompt", "consent")
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else if h.isAzureProv() {
params.Set("response_mode", "query")
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
}
} else {
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
} else {
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
}
} else {
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
}
if len(scopes) > 0 {
finalScopeString := strings.Join(scopes, " ")
params.Set("scope", finalScopeString)
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
}
return h.buildURLWithParams(h.authURL, params)
}
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
// It handles both relative and absolute URLs, validates URL security,
// and properly encodes query parameters.
func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string {
if baseURL != "" {
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
if err := h.validateURL(baseURL); err != nil {
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
return ""
}
}
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
issuerURLParsed, err := url.Parse(h.issuerURL)
if err != nil {
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
return ""
}
baseURLParsed, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
return ""
}
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
if err := h.validateURL(resolvedURL.String()); err != nil {
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
return ""
}
resolvedURL.RawQuery = params.Encode()
return resolvedURL.String()
}
u, err := url.Parse(baseURL)
if err != nil {
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
return ""
}
if err := h.validateParsedURL(u); err != nil {
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
return ""
}
u.RawQuery = params.Encode()
return u.String()
}
// validateURL performs security validation on URLs to prevent SSRF attacks.
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
func (h *AuthHandler) validateURL(urlStr string) error {
if urlStr == "" {
return fmt.Errorf("empty URL")
}
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %w", err)
}
return h.validateParsedURL(u)
}
// validateParsedURL validates a parsed URL structure for security.
// It checks schemes, hosts, and paths to prevent malicious URLs.
func (h *AuthHandler) validateParsedURL(u *url.URL) error {
allowedSchemes := map[string]bool{
"https": true,
"http": true,
}
if !allowedSchemes[u.Scheme] {
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
}
if u.Scheme == "http" {
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
}
if u.Host == "" {
return fmt.Errorf("missing host in URL")
}
if err := h.validateHost(u.Host); err != nil {
return fmt.Errorf("invalid host: %w", err)
}
if strings.Contains(u.Path, "..") {
return fmt.Errorf("path traversal detected in URL path")
}
return nil
}
// validateHost validates a hostname for security and reachability.
// It prevents access to private networks and localhost addresses.
func (h *AuthHandler) validateHost(host string) error {
if host == "" {
return fmt.Errorf("empty host")
}
// Strip port if present
if strings.Contains(host, ":") {
var err error
host, _, err = net.SplitHostPort(host)
if err != nil {
return fmt.Errorf("invalid host:port format: %w", err)
}
}
// Check for localhost variations
localhostVariations := []string{
"localhost", "127.0.0.1", "::1", "0.0.0.0",
}
for _, localhost := range localhostVariations {
if strings.EqualFold(host, localhost) {
return fmt.Errorf("localhost access not allowed: %s", host)
}
}
// Try to parse as IP address
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() {
return fmt.Errorf("loopback IP not allowed: %s", host)
}
if ip.IsPrivate() {
return fmt.Errorf("private IP not allowed: %s", host)
}
if ip.IsLinkLocalUnicast() {
return fmt.Errorf("link-local IP not allowed: %s", host)
}
if ip.IsMulticast() {
return fmt.Errorf("multicast IP not allowed: %s", host)
}
}
return nil
}
// SessionData interface for dependency injection
type SessionData interface {
GetRedirectCount() int
ResetRedirectCount()
IncrementRedirectCount()
SetAuthenticated(bool)
SetEmail(string)
SetAccessToken(string)
SetRefreshToken(string)
SetIDToken(string)
SetNonce(string)
SetCodeVerifier(string)
SetCSRF(string)
SetIncomingPath(string)
MarkDirty()
Save(req *http.Request, rw http.ResponseWriter) error
}
+391
View File
@@ -0,0 +1,391 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"time"
)
// validateRedirectCount checks if redirect limit is exceeded and handles the error
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
const maxRedirects = 5
redirectCount := session.GetRedirectCount()
if redirectCount >= maxRedirects {
t.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
session.ResetRedirectCount()
t.sendErrorResponse(rw, req, "Authentication failed: Too many redirects", http.StatusLoopDetected)
return fmt.Errorf("redirect limit exceeded")
}
session.IncrementRedirectCount()
return nil
}
// generatePKCEParameters generates PKCE code verifier and challenge if PKCE is enabled
func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
if !t.enablePKCE {
return "", "", nil
}
codeVerifier, err := generateCodeVerifier()
if err != nil {
return "", "", fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := deriveCodeChallenge(codeVerifier)
t.logger.Debugf("PKCE enabled, generated code challenge")
return codeVerifier, codeChallenge, nil
}
// prepareSessionForAuthentication clears existing session data and sets new authentication state
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
session.SetNonce("")
session.SetCodeVerifier("")
// Set new authentication state
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
if t.enablePKCE && codeVerifier != "" {
session.SetCodeVerifier(codeVerifier)
}
session.SetIncomingPath(incomingPath)
t.logger.Debugf("Storing incoming path: %s", incomingPath)
}
// defaultInitiateAuthentication initiates the OIDC authentication flow.
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
// stores authentication state, and redirects the user to the OIDC provider.
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request initiating authentication.
// - session: The session data to prepare for authentication.
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", t.originalRequestURI(req))
// Check and handle redirect limits
if err := t.validateRedirectCount(session, rw, req); err != nil {
return
}
csrfToken, err := newUUIDv4()
if err != nil {
t.logger.Errorf("Failed to generate CSRF token: %v", err)
http.Error(rw, "Failed to generate CSRF token", http.StatusInternalServerError)
return
}
nonce, err := generateNonce()
if err != nil {
t.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Generate PKCE parameters if enabled
codeVerifier, codeChallenge, err := t.generatePKCEParameters()
if err != nil {
t.logger.Errorf("Failed to generate PKCE parameters: %v", err)
http.Error(rw, "Failed to generate PKCE parameters", http.StatusInternalServerError)
return
}
// Clear existing session data and set new authentication state
t.prepareSessionForAuthentication(session, csrfToken, nonce, codeVerifier, t.originalRequestURI(req))
session.MarkDirty()
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
t.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
csrfToken, nonce)
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// handleCallback processes the OIDC callback after user authentication.
// It validates state/CSRF tokens, exchanges authorization code for tokens,
// verifies the received tokens, extracts claims, and establishes the session.
// Parameters:
// - rw: The HTTP response writer.
// - req: The callback request containing authorization code and state.
// - redirectURL: The fully qualified callback URL (used in the token exchange request).
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error during callback: %v", err)
t.sendErrorResponse(rw, req, "Session error during callback", http.StatusInternalServerError)
return
}
defer session.returnToPoolSafely()
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error")
}
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
session.GetAuthenticated(), req.URL.String())
cookie, err := req.Cookie("_oidc_raczylo_m")
if err != nil {
t.logger.Errorf("Main session cookie not found in request: %v", err)
} else {
t.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
}
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session during callback")
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
if err = t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
// Try "sub" as fallback since it's required by OIDC spec
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
return
}
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
}
// Validate user authorization
if !t.isAllowedUser(userIdentifier) {
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
return
}
if err := session.SetAuthenticated(true); err != nil {
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetUserIdentifier(userIdentifier)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
session.ResetRedirectCount()
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after callback: %v", err)
t.sendErrorResponse(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
return
}
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// handleExpiredToken handles requests with expired or invalid tokens.
// It clears the session data and initiates a new authentication flow.
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request with expired token.
// - session: The session data to clear.
// - redirectURL: The callback URL to be used in the new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication on expired token
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetUserIdentifier("")
// Clear CSRF tokens to prevent replay attacks
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
// Reset redirect count to prevent loops when handling expired tokens
session.ResetRedirectCount()
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// isUserAuthenticated determines the authentication status and refresh requirements.
// It delegates to provider-specific validation methods that handle different token types
// and expiration behaviors.
// Parameters:
// - session: The session data containing authentication tokens.
//
// Returns:
// - authenticated (bool): True if the user has valid tokens.
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
// - expired (bool): True if the session is unauthenticated, the token is missing,
// or the token verification failed for reasons other than nearing/actual expiration.
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if t.isAzureProvider() {
return t.validateAzureTokens(session)
} else if t.isGoogleProvider() {
return t.validateGoogleTokens(session)
}
// Auth0 and other providers can now use standard validation
// which handles opaque tokens generically
return t.validateStandardTokens(session)
}
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
xhr := req.Header.Get("X-Requested-With")
contentType := req.Header.Get("Content-Type")
accept := req.Header.Get("Accept")
return xhr == "XMLHttpRequest" ||
strings.Contains(contentType, "application/json") ||
strings.Contains(accept, "application/json")
}
// isNonNavigationRequest reports whether the request is a browser
// sub-resource (script, image, stylesheet, fetch, serviceWorker) rather than
// a top-level HTML navigation. Non-navigation requests MUST NOT trigger an
// OIDC redirect flow: several sub-resource loads happening in parallel would
// each call defaultInitiateAuthentication, each overwriting the session's
// CSRF/nonce, breaking the eventual callback (issue #129).
//
// Detection prefers Sec-Fetch-Mode, which all modern browsers send
// (Chrome/Edge/Firefox/Safari). For older or non-browser clients we fall
// back to Accept: if Accept is present and does not list text/html, treat
// it as a sub-resource. An empty/missing Accept is assumed to be navigation
// (safer to redirect than 401 on an ambiguous request).
func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
if mode := req.Header.Get("Sec-Fetch-Mode"); mode != "" {
return mode != "navigate"
}
accept := req.Header.Get("Accept")
if accept == "" || accept == "*/*" {
return false
}
return !strings.Contains(accept, "text/html")
}
// isRefreshTokenExpired checks whether the stored refresh token is likely
// past its useful lifetime, using the cookie-side issued_at timestamp set by
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
// MaxRefreshTokenAgeSeconds; 0 disables the check).
//
// The point of this check is to short-circuit the refresh path BEFORE the
// thundering herd hits the IdP for a token the provider has almost certainly
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
// style polling clients from looping on invalid_grant after a long pause.
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
if t == nil || session == nil {
return false
}
if t.maxRefreshTokenAge <= 0 {
return false
}
issuedAt := session.GetRefreshTokenIssuedAt()
if issuedAt.IsZero() {
// No timestamp recorded (legacy session pre-dating the issued_at
// field). Don't force a re-auth - attempt refresh once and let the
// IdP be the source of truth.
return false
}
return time.Since(issuedAt) > t.maxRefreshTokenAge
}
File diff suppressed because it is too large Load Diff
+101
View File
@@ -0,0 +1,101 @@
package traefikoidc
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestGeneratePKCEParameters tests the generatePKCEParameters method
func TestGeneratePKCEParameters(t *testing.T) {
t.Run("PKCE enabled - successful generation", func(t *testing.T) {
// Create a TraefikOidc instance with PKCE enabled
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
assert.NotEmpty(t, verifier, "code verifier should not be empty when PKCE is enabled")
assert.NotEmpty(t, challenge, "code challenge should not be empty when PKCE is enabled")
// Verify the challenge is derived from the verifier
expectedChallenge := deriveCodeChallenge(verifier)
assert.Equal(t, expectedChallenge, challenge, "challenge should match derived challenge from verifier")
})
t.Run("PKCE disabled - returns empty strings", func(t *testing.T) {
// Create a TraefikOidc instance with PKCE disabled
plugin := &TraefikOidc{
enablePKCE: false,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
assert.Empty(t, verifier, "code verifier should be empty when PKCE is disabled")
assert.Empty(t, challenge, "code challenge should be empty when PKCE is disabled")
})
t.Run("PKCE enabled - generates different values each time", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier1, challenge1, err1 := plugin.generatePKCEParameters()
require.NoError(t, err1)
verifier2, challenge2, err2 := plugin.generatePKCEParameters()
require.NoError(t, err2)
assert.NotEqual(t, verifier1, verifier2, "verifiers should be different")
assert.NotEqual(t, challenge1, challenge2, "challenges should be different")
})
t.Run("PKCE enabled - verifier and challenge relationship", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// The challenge should always be derivable from the verifier
recalculatedChallenge := deriveCodeChallenge(verifier)
assert.Equal(t, challenge, recalculatedChallenge,
"challenge should always match the SHA256 hash of verifier")
})
t.Run("PKCE enabled - verifier meets RFC 7636 requirements", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
verifier, _, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// RFC 7636 requires verifier to be 43-128 characters
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
})
t.Run("PKCE enabled - challenge meets RFC 7636 requirements", func(t *testing.T) {
plugin := &TraefikOidc{
enablePKCE: true,
logger: NewLogger("debug"),
}
_, challenge, err := plugin.generatePKCEParameters()
require.NoError(t, err)
// SHA256 hash base64 encoded should be 43 characters
assert.Equal(t, 43, len(challenge), "S256 challenge should be exactly 43 characters")
})
}
+32 -28
View File
@@ -173,7 +173,7 @@ func (bt *BackgroundTask) run() {
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Starting background task: %s", bt.name)
bt.logger.Debug("Starting background task: %s", bt.name)
}
}
@@ -182,7 +182,7 @@ func (bt *BackgroundTask) run() {
case <-bt.stopChan:
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Stopping background task: %s (before initial execution)", bt.name)
bt.logger.Debug("Stopping background task: %s (before initial execution)", bt.name)
}
}
return
@@ -201,7 +201,7 @@ func (bt *BackgroundTask) run() {
case <-bt.stopChan:
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Stopping background task: %s (during periodic execution)", bt.name)
bt.logger.Debug("Stopping background task: %s (during periodic execution)", bt.name)
}
}
return
@@ -211,7 +211,7 @@ func (bt *BackgroundTask) run() {
case <-bt.stopChan:
if bt.logger != nil {
if !isTestMode() {
bt.logger.Info("Stopping background task: %s (direct stop signal)", bt.name)
bt.logger.Debug("Stopping background task: %s (direct stop signal)", bt.name)
}
}
return
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
// It limits concurrent task execution and tracks failures to prevent system overload
type TaskCircuitBreaker struct {
state int32 // CircuitBreakerState
failureCount int32
lastFailureTime int64 // Unix timestamp
failureThreshold int32
timeout time.Duration
logger *Logger
// Concurrency limiting
concurrentTasks int32 // Current number of running tasks
maxConcurrent int32 // Maximum concurrent tasks allowed
activeTasks map[string]struct{} // Track active task names
tasksMu sync.RWMutex // Separate mutex for task tracking
activeTasks map[string]struct{}
lastFailureTime int64
timeout time.Duration
tasksMu sync.RWMutex
state int32
failureCount int32
failureThreshold int32
concurrentTasks int32
maxConcurrent int32
}
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
@@ -266,18 +265,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
max := atomic.LoadInt32(&cb.maxConcurrent)
// For cleanup tasks, be more restrictive (singleton-like behavior)
// However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456)
if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") {
cb.tasksMu.RLock()
hasCleanupTask := false
hasSameTask := false
for activeTask := range cb.activeTasks {
if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") {
hasCleanupTask = true
// Only block if the EXACT same task is already running
// This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently
if activeTask == taskName {
hasSameTask = true
break
}
}
cb.tasksMu.RUnlock()
if hasCleanupTask {
if hasSameTask {
return fmt.Errorf("cleanup/singleton task already running: %s", taskName)
}
}
@@ -315,7 +317,7 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
if time.Now().Unix()-lastFailure > int64(cb.timeout.Seconds()) {
atomic.StoreInt32(&cb.state, int32(CircuitBreakerHalfOpen))
if cb.logger != nil {
cb.logger.Info("Circuit breaker transitioning to half-open for task: %s", taskName)
cb.logger.Debug("Circuit breaker transitioning to half-open for task: %s", taskName)
}
return nil
}
@@ -377,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
type TaskRegistry struct {
tasks map[string]*BackgroundTask
mu sync.RWMutex
cb *TaskCircuitBreaker
logger *Logger
mu sync.RWMutex
}
// GlobalTaskRegistry is the singleton instance for managing all background tasks
@@ -467,7 +469,7 @@ func (tr *TaskRegistry) RegisterTask(name string, task *BackgroundTask) error {
tr.cb.OnTaskSuccess(name)
if tr.logger != nil {
tr.logger.Info("Registered background task: %s", name)
tr.logger.Debug("Registered background task: %s", name)
}
return nil
@@ -483,7 +485,7 @@ func (tr *TaskRegistry) UnregisterTask(name string) {
delete(tr.tasks, name)
if tr.logger != nil {
tr.logger.Info("Unregistered background task: %s", name)
tr.logger.Debug("Unregistered background task: %s", name)
}
}
}
@@ -513,7 +515,7 @@ func (tr *TaskRegistry) StopAllTasks() {
for name, task := range tasksCopy {
task.Stop()
if tr.logger != nil {
tr.logger.Info("Stopped background task during shutdown: %s", name)
tr.logger.Debug("Stopped background task during shutdown: %s", name)
}
}
}
@@ -538,7 +540,7 @@ func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
// Start the task if not already running
if !rm.IsTaskRunning(name) {
rm.StartBackgroundTask(name)
_ = rm.StartBackgroundTask(name) // Safe to ignore: task registration succeeded, start is best-effort
}
// Get the task from resource manager's internal registry
@@ -597,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
return globalTaskMemoryMonitor
}
// NewTaskMemoryMonitor creates a new memory monitor for task registry
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
//
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
return GetGlobalTaskMemoryMonitor(logger)
}
@@ -641,7 +644,7 @@ func (mm *TaskMemoryMonitor) Start(interval time.Duration) error {
mm.started = true
if mm.logger != nil && !isTestMode() {
mm.logger.Info("Started global task memory monitoring with %v interval", interval)
mm.logger.Debug("Started global task memory monitoring with %v interval", interval)
}
return nil
@@ -710,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
// Check for goroutine leaks (arbitrary threshold)
if stats.Goroutines > 100 {
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
}
// Check for heap growth without corresponding GC activity
@@ -787,6 +790,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
}
if mm.logger != nil {
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
}
+224
View File
@@ -0,0 +1,224 @@
package traefikoidc
import (
"errors"
"sync"
"testing"
"time"
)
// globalRegistryMutex protects only the global registry operations
var globalRegistryMutex sync.Mutex
// TestTaskCircuitBreakerOnTaskFailure tests the OnTaskFailure method
func TestTaskCircuitBreakerOnTaskFailure(t *testing.T) {
logger := NewLogger("debug") // Create a real logger
cb := NewTaskCircuitBreaker(3, time.Minute, logger)
// Test failure doesn't trigger open state before threshold
cb.OnTaskFailure("test-task", errors.New("test error"))
if err := cb.CanCreateTask("test-task"); err != nil {
t.Error("Circuit breaker should allow task creation after 1 failure (threshold: 3)")
}
// Test failure count reaches threshold and opens circuit
cb.OnTaskFailure("test-task", errors.New("test error 2"))
cb.OnTaskFailure("test-task", errors.New("test error 3"))
if err := cb.CanCreateTask("test-task"); err == nil {
t.Error("Circuit breaker should prevent task creation after reaching failure threshold")
}
}
// TestResetGlobalTaskRegistry tests the reset functionality
func TestResetGlobalTaskRegistry(t *testing.T) {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
// Get the global registry first
registry := GetGlobalTaskRegistry()
// Create and register a dummy task
logger := NewLogger("debug")
task := NewBackgroundTask("test-task", time.Second, func() {
// Do nothing
}, logger)
registry.RegisterTask("test-task", task)
// Verify task is registered
if registry.GetTaskCount() == 0 {
t.Error("Expected task to be registered")
}
// Reset the registry
ResetGlobalTaskRegistry()
// Get registry again and verify it's empty
newRegistry := GetGlobalTaskRegistry()
if newRegistry.GetTaskCount() != 0 {
t.Error("Expected registry to be empty after reset")
}
}
// TestGetTask tests the GetTask method
func TestGetTask(t *testing.T) {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
// Reset registry to ensure clean state
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
// Test getting non-existent task
task, exists := registry.GetTask("non-existent")
if task != nil || exists {
t.Error("Expected nil and false for non-existent task")
}
// Create and register a task
logger := NewLogger("debug")
newTask := NewBackgroundTask("test-task", time.Second, func() {
// Do nothing
}, logger)
registry.RegisterTask("test-task", newTask)
// Test getting existing task
retrievedTask, exists := registry.GetTask("test-task")
if retrievedTask == nil || !exists {
t.Error("Expected to retrieve registered task")
return
}
if retrievedTask.name != "test-task" {
t.Errorf("Expected task name 'test-task', got '%s'", retrievedTask.name)
}
}
// TestNewTaskMemoryMonitor tests the NewTaskMemoryMonitor function
func TestNewTaskMemoryMonitor(t *testing.T) {
// No mutex needed - this doesn't modify global state
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
if monitor == nil {
t.Error("Expected NewTaskMemoryMonitor to return non-nil monitor")
}
}
// TestGetCurrentStats tests the GetCurrentStats method
func TestGetCurrentStats(t *testing.T) {
// Don't hold mutex during background task execution to avoid deadlocks
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
// Start the monitor and let it collect at least one statistic
err := monitor.Start(50 * time.Millisecond)
if err != nil {
t.Fatalf("Failed to start monitor: %v", err)
}
// Ensure monitor is stopped even if test fails
defer func() {
monitor.Stop()
// Give extra time for cleanup
time.Sleep(50 * time.Millisecond)
}()
// Wait a bit for the monitor to collect stats
time.Sleep(150 * time.Millisecond)
stats, err := monitor.GetCurrentStats()
if err != nil {
// If no stats are available yet, that's acceptable for this test
t.Logf("No memory statistics available yet: %v", err)
return
}
// TaskMemoryStats is a struct, not a pointer, so it can't be nil
if stats.Timestamp.IsZero() {
t.Error("Expected GetCurrentStats to return valid timestamp")
}
}
// TestGetStatsHistory tests the GetStatsHistory method
func TestGetStatsHistory(t *testing.T) {
// No mutex needed - this just creates a monitor and checks its initial state
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
history := monitor.GetStatsHistory()
if history == nil {
t.Error("Expected GetStatsHistory to return non-nil history")
}
// A fresh monitor should have empty history
if len(history) != 0 {
t.Logf("History length: %d (may be non-empty due to shared global state)", len(history))
}
}
// TestForceGC tests the ForceGC method
func TestForceGC(t *testing.T) {
// No mutex needed - this doesn't modify global state
logger := NewLogger("debug")
registry := GetGlobalTaskRegistry()
monitor := NewTaskMemoryMonitor(logger, registry)
// This should not panic and should work
monitor.ForceGC()
// No specific verification needed, just ensuring it doesn't crash
}
// TestShutdownAllTasks tests the ShutdownAllTasks function
func TestShutdownAllTasks(t *testing.T) {
// Use a unique task name prefix to avoid conflicts with other tests
taskPrefix := "shutdown-test-"
// Create a temporary clean registry state
func() {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
ResetGlobalTaskRegistry()
}()
registry := GetGlobalTaskRegistry()
logger := NewLogger("debug")
// Create some test tasks with unique names
task1 := NewBackgroundTask(taskPrefix+"task1", time.Millisecond, func() {
time.Sleep(100 * time.Millisecond) // Simulate work
}, logger)
task2 := NewBackgroundTask(taskPrefix+"task2", time.Millisecond, func() {
time.Sleep(100 * time.Millisecond) // Simulate work
}, logger)
// Register tasks under mutex protection
func() {
globalRegistryMutex.Lock()
defer globalRegistryMutex.Unlock()
registry.RegisterTask(taskPrefix+"task1", task1)
registry.RegisterTask(taskPrefix+"task2", task2)
}()
// Start the tasks (outside mutex to avoid deadlock)
task1.Start()
task2.Start()
// Give tasks time to start
time.Sleep(50 * time.Millisecond)
// Shutdown all tasks
ShutdownAllTasks()
// Give shutdown time to complete
time.Sleep(200 * time.Millisecond)
// Note: We can't reliably verify task count due to other tests
// Just ensure shutdown doesn't panic
}
+9 -8
View File
@@ -58,12 +58,13 @@ func TestAzureOIDCRegression(t *testing.T) {
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60 * time.Second,
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Add rate limiter
logger: mockLogger,
httpClient: createDefaultHTTPClient(), // Add HTTP client
httpClient: CreateDefaultHTTPClient(), // Add HTTP client
jwkCache: &JWKCache{}, // Add JWK cache
tokenCache: tokenCache,
tokenBlacklist: tokenBlacklist,
@@ -78,7 +79,7 @@ func TestAzureOIDCRegression(t *testing.T) {
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
// Initialize session manager
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", mockLogger)
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", "", 0, mockLogger)
tOidc.sessionManager = sessionManager
// Mock the JWT verification to avoid JWKS lookup issues
@@ -329,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "ValidGoogleTokens",
@@ -475,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
setupSession func() *SessionData
name string
providerType string
setupSession func() *SessionData
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "AzureProvider",
@@ -659,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "UnauthenticatedWithRefreshToken",
+545
View File
@@ -0,0 +1,545 @@
package traefikoidc
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestMemoryMonitorComprehensive tests memory monitor edge cases
func TestMemoryMonitorComprehensive(t *testing.T) {
t.Run("TriggerGC calls runtime GC", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Should not panic
assert.NotPanics(t, func() {
monitor.TriggerGC()
})
})
t.Run("GetMemoryPressure returns pressure level", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Initially should return None (no stats yet)
pressure := monitor.GetMemoryPressure()
assert.Equal(t, MemoryPressureNone, pressure)
// Explicitly sample to populate lastStats; GetCurrentStats is now a
// cached read and no longer forces a runtime.ReadMemStats.
monitor.Refresh()
// Now should return a valid pressure level
pressure = monitor.GetMemoryPressure()
assert.NotNil(t, pressure)
})
t.Run("StartMonitoring can be called", func(t *testing.T) {
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
defer ResetGlobalMemoryMonitor()
defer ResetGlobalTaskRegistry()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Start monitoring should not panic. Interval is clamped to the
// minimum (30s); we rely on Refresh() when we need a synchronous
// sample instead of waiting for a tick.
assert.NotPanics(t, func() {
ctx := context.Background()
monitor.StartMonitoring(ctx, 0)
monitor.Refresh()
})
// Clean up
monitor.StopMonitoring()
})
t.Run("StopMonitoring can be called safely", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// StopMonitoring should not panic even if not started
assert.NotPanics(t, func() {
monitor.StopMonitoring()
})
// Can be called multiple times safely
assert.NotPanics(t, func() {
monitor.StopMonitoring()
monitor.StopMonitoring()
})
})
t.Run("ResetGlobalMemoryMonitor resets singleton", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
// Get initial instance
GetGlobalMemoryMonitor()
// Reset
ResetGlobalMemoryMonitor()
// Should be able to get a new instance
monitor := GetGlobalMemoryMonitor()
assert.NotNil(t, monitor)
// Clean up
monitor.StopMonitoring()
ResetGlobalMemoryMonitor()
})
t.Run("String method returns pressure name", func(t *testing.T) {
pressures := []struct {
name string
level MemoryPressureLevel
}{
{level: MemoryPressureNone, name: "None"},
{level: MemoryPressureLow, name: "Low"},
{level: MemoryPressureModerate, name: "Moderate"},
{level: MemoryPressureHigh, name: "High"},
{level: MemoryPressureCritical, name: "Critical"},
{level: MemoryPressureLevel(999), name: "Unknown"},
}
for _, p := range pressures {
assert.Equal(t, p.name, p.level.String(), "pressure level %d should return %s", p.level, p.name)
}
})
t.Run("GetCurrentStats collects statistics", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Refresh forces a synchronous sample; GetCurrentStats is a cached
// read, so we sample first to guarantee fresh data.
monitor.Refresh()
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
assert.Greater(t, stats.NumGoroutines, 0)
assert.NotZero(t, stats.Timestamp)
})
}
// TestBackgroundTaskRegistry tests background task registry edge cases
func TestBackgroundTaskRegistry(t *testing.T) {
t.Run("GetGlobalTaskRegistry returns singleton", func(t *testing.T) {
registry1 := GetGlobalTaskRegistry()
registry2 := GetGlobalTaskRegistry()
assert.Equal(t, registry1, registry2, "should return same instance")
})
t.Run("RegisterTask adds task to registry", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
taskName := "test-register-task"
task := NewBackgroundTask(
taskName,
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
err := registry.RegisterTask(taskName, task)
assert.NoError(t, err)
// Verify task was registered
_, exists := registry.GetTask(taskName)
assert.True(t, exists, "task should be registered")
// Clean up
task.Stop()
})
t.Run("CreateSingletonTask is idempotent", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
taskName := "test-singleton-idempotent"
callCount := 0
var mu sync.Mutex
taskFunc := func() {
mu.Lock()
callCount++
mu.Unlock()
}
// First creation should succeed
task1, err1 := registry.CreateSingletonTask(
taskName,
100*time.Millisecond,
taskFunc,
newNoOpLogger(),
nil,
)
assert.NoError(t, err1)
assert.NotNil(t, task1)
// Second creation should also succeed (idempotent)
// Returns same task without error
task2, err2 := registry.CreateSingletonTask(
taskName,
100*time.Millisecond,
taskFunc,
newNoOpLogger(),
nil,
)
assert.NoError(t, err2, "CreateSingletonTask should be idempotent")
assert.NotNil(t, task2)
// Clean up
if task1 != nil {
task1.Stop()
}
})
t.Run("GetTaskCount returns active task count", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
// Initially should be 0 or small number
initialCount := registry.GetTaskCount()
// Create a task
task := NewBackgroundTask(
"count-test-task",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
err := registry.RegisterTask("count-test-task", task)
assert.NoError(t, err)
// Count should increase
newCount := registry.GetTaskCount()
assert.Equal(t, initialCount+1, newCount)
// Clean up
task.Stop()
})
t.Run("StopAllTasks stops all tasks", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
// Create multiple tasks
for i := 0; i < 3; i++ {
taskName := "multi-task-" + string(rune(i+'0'))
task := NewBackgroundTask(
taskName,
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
registry.RegisterTask(taskName, task)
}
// Verify tasks were created
assert.GreaterOrEqual(t, registry.GetTaskCount(), 3)
// Stop all tasks
registry.StopAllTasks()
// Verify all tasks are removed
taskCount := registry.GetTaskCount()
assert.Equal(t, 0, taskCount, "all tasks should be stopped")
})
t.Run("ResetGlobalTaskRegistry clears registry", func(t *testing.T) {
ResetGlobalTaskRegistry()
registry := GetGlobalTaskRegistry()
// Create a task
task := NewBackgroundTask(
"reset-test-task",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
registry.RegisterTask("reset-test-task", task)
// Reset
ResetGlobalTaskRegistry()
// Get new registry
newRegistry := GetGlobalTaskRegistry()
assert.Equal(t, 0, newRegistry.GetTaskCount(), "new registry should be empty")
})
}
// TestBackgroundTaskLifecycle tests background task lifecycle
func TestBackgroundTaskLifecycle(t *testing.T) {
t.Run("Start begins task execution", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
executed := false
var mu sync.Mutex
task := NewBackgroundTask(
"lifecycle-test",
50*time.Millisecond,
func() {
mu.Lock()
executed = true
mu.Unlock()
},
newNoOpLogger(),
)
// Start task
task.Start()
// Wait for execution
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Stop task
task.Stop()
// Verify it executed
mu.Lock()
wasExecuted := executed
mu.Unlock()
assert.True(t, wasExecuted, "task should have executed")
})
t.Run("Stop halts task execution", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
execCount := 0
var mu sync.Mutex
task := NewBackgroundTask(
"stop-test",
30*time.Millisecond,
func() {
mu.Lock()
execCount++
mu.Unlock()
},
newNoOpLogger(),
)
// Start task
task.Start()
// Let it run a few times
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Stop task
task.Stop()
// Record count
mu.Lock()
countAfterStop := execCount
mu.Unlock()
// Wait more
time.Sleep(GetTestDuration(100 * time.Millisecond))
// Count should not increase
mu.Lock()
finalCount := execCount
mu.Unlock()
assert.Equal(t, countAfterStop, finalCount, "task should not execute after stop")
})
t.Run("Multiple Start calls are safe", func(t *testing.T) {
if testing.Short() {
t.Skip("Skipping background task test in short mode")
}
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
execCount := 0
var mu sync.Mutex
task := NewBackgroundTask(
"multi-start-test",
100*time.Millisecond,
func() {
mu.Lock()
execCount++
mu.Unlock()
},
newNoOpLogger(),
)
// Multiple starts should be safe
task.Start()
task.Start()
task.Start()
// Wait a bit
time.Sleep(GetTestDuration(50 * time.Millisecond))
// Stop task
task.Stop()
// Should have executed, but only one goroutine
mu.Lock()
count := execCount
mu.Unlock()
assert.GreaterOrEqual(t, count, 0, "task should have executed at least once")
})
t.Run("Multiple Stop calls are safe", func(t *testing.T) {
ResetGlobalTaskRegistry()
defer ResetGlobalTaskRegistry()
task := NewBackgroundTask(
"multi-stop-test",
100*time.Millisecond,
func() {},
newNoOpLogger(),
)
// Start and stop
task.Start()
time.Sleep(GetTestDuration(20 * time.Millisecond))
// Multiple stops should be safe
assert.NotPanics(t, func() {
task.Stop()
task.Stop()
task.Stop()
})
})
}
// TestMemoryMonitorIntegration tests memory monitor integration
func TestMemoryMonitorIntegration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory monitor integration test in short mode")
}
t.Run("monitoring updates stats", func(t *testing.T) {
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
defer ResetGlobalMemoryMonitor()
defer ResetGlobalTaskRegistry()
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
defer monitor.StopMonitoring()
// Start monitoring. The interval is clamped to the minimum (30s) so
// the ticker won't fire during the test; drive the sample manually via
// Refresh() instead.
ctx := context.Background()
monitor.StartMonitoring(ctx, 0)
monitor.Refresh()
// Get pressure (should be a valid pressure level)
pressure := monitor.GetMemoryPressure()
assert.Contains(t, []MemoryPressureLevel{
MemoryPressureNone,
MemoryPressureLow,
MemoryPressureModerate,
MemoryPressureHigh,
MemoryPressureCritical,
}, pressure, "pressure should be a valid level")
// Stop monitoring
monitor.StopMonitoring()
})
t.Run("global memory monitor singleton", func(t *testing.T) {
ResetGlobalMemoryMonitor()
defer ResetGlobalMemoryMonitor()
monitor1 := GetGlobalMemoryMonitor()
monitor2 := GetGlobalMemoryMonitor()
assert.Equal(t, monitor1, monitor2, "should return same instance")
})
}
// TestMemoryStatsCollection tests memory statistics collection
func TestMemoryStatsCollection(t *testing.T) {
t.Run("GetCurrentStats returns valid data", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
monitor.Refresh()
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
assert.Greater(t, stats.HeapSysBytes, uint64(0))
assert.Greater(t, stats.NumGoroutines, 0)
assert.False(t, stats.Timestamp.IsZero())
})
t.Run("Stats include memory pressure", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
monitor.Refresh()
stats := monitor.GetCurrentStats()
// Should calculate and include pressure level
assert.NotNil(t, stats.MemoryPressure)
assert.Contains(t, []MemoryPressureLevel{
MemoryPressureNone,
MemoryPressureLow,
MemoryPressureModerate,
MemoryPressureHigh,
MemoryPressureCritical,
}, stats.MemoryPressure)
})
t.Run("TriggerGC reduces memory", func(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Allocate some memory
_ = make([]byte, 1024*1024) // 1MB
// Get stats before GC (explicit Refresh so we have a fresh pre-GC
// snapshot to compare against, not the constructor baseline).
beforeStats := monitor.Refresh()
// Trigger GC (internally Refresh()es before and after)
monitor.TriggerGC()
// Get stats after GC from cache (TriggerGC already refreshed it)
afterStats := monitor.GetCurrentStats()
// After GC should have different stats
assert.NotEqual(t, beforeStats.LastGCTime, afterStats.LastGCTime)
})
}
+592
View File
@@ -0,0 +1,592 @@
// Package traefikoidc — bearer-token (M2M) authentication path.
//
// Disabled by default. When enabled via Config.EnableBearerAuth, requests
// presenting "Authorization: Bearer <jwt>" are validated against the
// configured OIDC provider (signature, issuer, audience, exp, replay-Get)
// and the request is forwarded downstream without creating a cookie session.
//
// Design rules (kept here in code as the single source of truth):
// - Access tokens only. ID tokens are rejected via detectTokenType.
// - Audience is mandatory (enforced at startup in main.go).
// - alg + kid pinned BEFORE JWKS fetch to deny amplification probes.
// - iat upper-age cap bounds clock-skew / forever-token abuse.
// - Multi-audience tokens require matching azp.
// - Per-IP 401 throttle returns 429 + Retry-After after a threshold.
// - JTI Set is suppressed (skipReplayMarking) but JTI Get stays — revoked
// tokens (RevokeToken adds to blacklist) are still rejected.
// - Identifier is read from BearerIdentifierClaim (default "sub"), never
// from UserIdentifierClaim, to avoid the unverified-email spoofing path.
// - Identifier is sanitized: length cap, control chars, bidi-override,
// delimiter chars (, ; =) rejected.
// - On excluded URLs the Authorization header is stripped before forwarding.
//
// See docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md and
// docs/BEARER_AUTH.md for the full threat model.
package traefikoidc
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"unicode"
)
const bearerPrefix = "Bearer "
// bearerAlgAllowlist is the set of JWS algorithms accepted on the bearer
// path. Asymmetric-only — HS* would allow public-key-as-HMAC-secret attacks
// if any operator ever rotates a key into the symmetric branch by mistake;
// "none" is obvious. Matches the allowlist enforced inside jwt.Verify but is
// checked here BEFORE the JWKS fetch so attacker noise can't amplify.
var bearerAlgAllowlist = map[string]struct{}{
"RS256": {}, "RS384": {}, "RS512": {},
"PS256": {}, "PS384": {}, "PS512": {},
"ES256": {}, "ES384": {}, "ES512": {},
}
// bearerKidMaxLen caps the JOSE kid header length to keep memory and cache-key
// usage bounded against attacker-controlled values.
const bearerKidMaxLen = 256
// validKidChar is the allowlist for kid header characters. Letters, digits,
// dot, underscore, hyphen, equals. Intentionally narrow; real-world kid
// values are short URL-safe-base64-ish identifiers.
func validKidChar(r rune) bool {
if r >= 'a' && r <= 'z' {
return true
}
if r >= 'A' && r <= 'Z' {
return true
}
if r >= '0' && r <= '9' {
return true
}
switch r {
case '.', '_', '-', '=':
return true
}
return false
}
// bearerError categorizes failure modes for the response builder. Categories
// map 1:1 to the table in docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md
// §9 so behavior is auditable from spec to code.
type bearerErrorKind int
const (
bearerErrInvalidRequest bearerErrorKind = iota
bearerErrInvalidToken
bearerErrTokenInactive
bearerErrInvalidIdentifier
bearerErrForbidden
bearerErrThrottled
bearerErrIntrospectionUnavailable
)
type bearerError struct {
kind bearerErrorKind
reason string
}
func (e *bearerError) Error() string { return e.reason }
func newBearerError(kind bearerErrorKind, reason string) *bearerError {
return &bearerError{kind: kind, reason: reason}
}
// joseHeader is the minimal subset of the JWS protected header we inspect
// BEFORE running the full verification pipeline. Lifted out so the alg+kid
// pin can run without paying for parseJWT's full claim decode.
type joseHeader struct {
Alg string `json:"alg"`
Kid string `json:"kid"`
Typ string `json:"typ"`
}
// parseBearerJOSEHeader decodes the first JWT segment for early alg/kid pinning.
// Does not touch the payload or signature — those are the verifier's job.
// Returns nil on success; *bearerError on rejection so the handler can map
// directly to a status code. The decoded header itself is not surfaced because
// callers don't need it (verifyTokenWithOpts re-parses internally).
func parseBearerJOSEHeader(token string) *bearerError {
dot := strings.IndexByte(token, '.')
if dot <= 0 {
return newBearerError(bearerErrInvalidToken, "malformed JWT: no header segment")
}
raw, err := base64.RawURLEncoding.DecodeString(token[:dot])
if err != nil {
// Some IdPs pad with '='; tolerate by retrying with StdEncoding.
raw, err = base64.URLEncoding.DecodeString(token[:dot])
if err != nil {
return newBearerError(bearerErrInvalidToken, "malformed JWT: header not base64url")
}
}
var hdr joseHeader
if err := json.Unmarshal(raw, &hdr); err != nil {
return newBearerError(bearerErrInvalidToken, "malformed JWT: header not JSON")
}
if _, ok := bearerAlgAllowlist[hdr.Alg]; !ok {
return newBearerError(bearerErrInvalidToken, fmt.Sprintf("disallowed alg %q on bearer path", hdr.Alg))
}
if hdr.Kid == "" {
return newBearerError(bearerErrInvalidToken, "missing kid header")
}
if len(hdr.Kid) > bearerKidMaxLen {
return newBearerError(bearerErrInvalidToken, "kid header exceeds max length")
}
for _, r := range hdr.Kid {
if !validKidChar(r) {
return newBearerError(bearerErrInvalidToken, "kid header contains disallowed characters")
}
}
return nil
}
// sanitizeBearerIdentifier validates and trims a principal identifier before
// it is injected into request headers. Layered defense: net/http will reject
// CRLF on the wire too, but rejecting early gives clearer error logs and
// prevents bidi-override / delimiter chars that pass net/http's narrower
// checks but confuse downstream parsers and admin UIs.
func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
identifier := strings.TrimSpace(raw)
if identifier == "" {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier claim empty")
}
if maxLen > 0 && len(identifier) > maxLen {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
}
for _, r := range identifier {
if unicode.IsControl(r) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains control character")
}
// Unicode bidi-override range (RTL spoofing of admin UI / SIEM).
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains bidi-override character")
}
if r == ',' || r == ';' || r == '=' {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains delimiter character")
}
}
return identifier, nil
}
// resolveBearerIdentifier picks the principal identifier from claims using
// the configured BearerIdentifierClaim (default "sub"). Decoupled from
// userIdentifierClaim (cookie path) to avoid the unverified-email spoofing
// vector documented in the spec §13.
func resolveBearerIdentifier(claims map[string]interface{}, claimName string) (string, *bearerError) {
if claimName == "" {
claimName = "sub"
}
raw, ok := claims[claimName]
if !ok {
return "", newBearerError(bearerErrInvalidIdentifier, fmt.Sprintf("missing claim %q", claimName))
}
str, ok := raw.(string)
if !ok {
return "", newBearerError(bearerErrInvalidIdentifier, fmt.Sprintf("claim %q not a string", claimName))
}
return str, nil
}
// enforceMultiAudienceAzp implements the spec hardening: when aud is a
// multi-element array, require an azp claim equal to clientID. Single-string
// aud is unaffected (existing verifyAudience handles it).
func enforceMultiAudienceAzp(claims map[string]interface{}, clientID string) *bearerError {
audRaw, ok := claims["aud"]
if !ok {
return nil // verifyToken already rejects missing aud
}
arr, ok := audRaw.([]interface{})
if !ok {
return nil // single-string aud
}
if len(arr) <= 1 {
return nil
}
azpRaw, ok := claims["azp"]
if !ok {
return newBearerError(bearerErrInvalidToken, "multi-audience token missing azp")
}
azp, ok := azpRaw.(string)
if !ok || azp == "" {
return newBearerError(bearerErrInvalidToken, "multi-audience token has empty/non-string azp")
}
if azp != clientID {
return newBearerError(bearerErrInvalidToken, "multi-audience token azp does not match clientID")
}
return nil
}
// enforceIatAge implements the spec MaxTokenAgeSeconds bound on iat. Bounds
// clock-manipulation / forever-token abuse without rejecting tokens with a
// normal iat just because the issuer's clock skews a few seconds.
func enforceIatAge(claims map[string]interface{}, maxAge time.Duration) *bearerError {
if maxAge <= 0 {
return nil
}
iatRaw, ok := claims["iat"].(float64)
if !ok {
// jwt.Verify already requires iat; this branch shouldn't be reached.
return newBearerError(bearerErrInvalidToken, "missing iat claim")
}
iat := time.Unix(int64(iatRaw), 0)
if time.Since(iat) > maxAge {
return newBearerError(bearerErrInvalidToken, "token iat outside age bound")
}
return nil
}
// hashIdentifierForLog returns a short SHA-256 prefix safe for info-level
// logs. Full identifier is only emitted at debug. Satisfies the audit
// requirement (trace which principal was rejected) without leaking PII.
func hashIdentifierForLog(identifier string) string {
if identifier == "" {
return "(none)"
}
sum := sha256.Sum256([]byte(identifier))
return hex.EncodeToString(sum[:4]) // 8 hex chars
}
// --- Per-IP failure throttle ---
// bearerFailureTracker records consecutive bearer-auth 401s per source IP and
// parks repeat offenders in a 429 penalty box. Limits offline-guessing-style
// attacks and protects the shared rate-limiter / JWKS endpoint from being
// burned by a single source.
type bearerFailureTracker struct {
mu sync.Mutex
entries map[string]*bearerFailureEntry
// Configuration snapshot. Captured at construction so a hot reconfigure
// doesn't race with the per-request paths.
threshold int
window time.Duration
penalty time.Duration
}
type bearerFailureEntry struct {
firstFailureAt time.Time
penaltyUntil time.Time
count int
}
func newBearerFailureTracker(threshold int, window, penalty time.Duration) *bearerFailureTracker {
if threshold <= 0 {
threshold = 20
}
if window <= 0 {
window = 60 * time.Second
}
if penalty <= 0 {
penalty = 60 * time.Second
}
return &bearerFailureTracker{
entries: make(map[string]*bearerFailureEntry),
threshold: threshold,
window: window,
penalty: penalty,
}
}
// blocked reports whether the source IP is currently in the penalty box.
// Returns (true, retryAfter) when blocked; (false, 0) when allowed.
func (b *bearerFailureTracker) blocked(ip string) (bool, time.Duration) {
if b == nil || ip == "" {
return false, 0
}
b.mu.Lock()
defer b.mu.Unlock()
e, ok := b.entries[ip]
if !ok {
return false, 0
}
now := time.Now()
if !e.penaltyUntil.IsZero() && now.Before(e.penaltyUntil) {
return true, time.Until(e.penaltyUntil)
}
return false, 0
}
// recordFailure increments the failure counter for the given IP and trips
// the penalty box once threshold-within-window is exceeded.
func (b *bearerFailureTracker) recordFailure(ip string) {
if b == nil || ip == "" {
return
}
b.mu.Lock()
defer b.mu.Unlock()
now := time.Now()
e, ok := b.entries[ip]
if !ok || now.Sub(e.firstFailureAt) > b.window {
e = &bearerFailureEntry{firstFailureAt: now}
b.entries[ip] = e
}
e.count++
if e.count >= b.threshold {
e.penaltyUntil = now.Add(b.penalty)
}
}
// recordSuccess clears the failure counter for the given IP after a
// successful bearer auth.
func (b *bearerFailureTracker) recordSuccess(ip string) {
if b == nil || ip == "" {
return
}
b.mu.Lock()
defer b.mu.Unlock()
delete(b.entries, ip)
}
// clientIPForBearer returns the source IP used to key the failure tracker.
// Trusts only the request's transport-level RemoteAddr; X-Forwarded-For is
// intentionally ignored to avoid attacker-controlled key spoofing. Behind a
// trusted reverse proxy where every request shares one IP, the throttle is
// still useful (caps attacker churn through that proxy) — operators wanting
// per-real-client throttling must terminate at this middleware.
func clientIPForBearer(req *http.Request) string {
if req == nil {
return ""
}
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
return req.RemoteAddr
}
return host
}
// --- Bearer auth entrypoint ---
// detectBearerToken returns (token, true) when the request carries a usable
// Authorization: Bearer header. Case-insensitive on the scheme. Returns
// ("", false) for any other shape.
func detectBearerToken(req *http.Request) (string, bool) {
if req == nil {
return "", false
}
h := req.Header.Get("Authorization")
if len(h) < len(bearerPrefix) {
return "", false
}
if !strings.EqualFold(h[:len(bearerPrefix)], bearerPrefix) {
return "", false
}
token := strings.TrimSpace(h[len(bearerPrefix):])
if token == "" {
return "", false
}
return token, true
}
// hasSessionCookie reports whether the request carries any cookie matching
// the session prefix. Used to implement the cookie-wins-by-default
// precedence rule when both bearer and cookie are present.
func (t *TraefikOidc) hasSessionCookie(req *http.Request) bool {
if t.sessionManager == nil {
return false
}
prefix := t.sessionManager.GetCookiePrefix()
if prefix == "" {
return false
}
for _, c := range req.Cookies() {
if strings.HasPrefix(c.Name, prefix) {
return true
}
}
return false
}
// writeBearerError writes the canonical 401/403/429/503 response per spec §9.
// Body is always generic; reason is logged at debug only. The
// WWW-Authenticate hint is gated by config (default on, RFC 6750 compliant).
func (t *TraefikOidc) writeBearerError(rw http.ResponseWriter, req *http.Request, err *bearerError) {
var (
status int
errCode string
body string
retryAfter time.Duration
)
switch err.kind {
case bearerErrInvalidRequest:
status = http.StatusUnauthorized
errCode = "invalid_request"
body = "Unauthorized"
case bearerErrInvalidToken, bearerErrTokenInactive, bearerErrInvalidIdentifier:
status = http.StatusUnauthorized
errCode = "invalid_token"
body = "Unauthorized"
case bearerErrForbidden:
status = http.StatusForbidden
body = "Access denied"
case bearerErrThrottled:
status = http.StatusTooManyRequests
body = "Too Many Requests"
retryAfter = t.bearerFailurePenalty
case bearerErrIntrospectionUnavailable:
status = http.StatusServiceUnavailable
body = "Service Unavailable"
default:
status = http.StatusUnauthorized
body = "Unauthorized"
}
if t.bearerEmitWWWAuthenticate && errCode != "" {
rw.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error=%q`, errCode))
}
if retryAfter > 0 {
rw.Header().Set("Retry-After", fmt.Sprintf("%d", int(retryAfter.Seconds())))
}
rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
rw.WriteHeader(status)
_, _ = rw.Write([]byte(body)) // Safe to ignore: best-effort error body write
if t.logger != nil {
t.logger.Debugf("bearer auth rejected: status=%d category=%v reason=%q path=%s",
status, err.kind, err.reason, req.URL.Path)
}
}
// handleBearerRequest is the entry point invoked by ServeHTTP when the
// EnableBearerAuth flag is set, the request carries an Authorization: Bearer
// header, and the (configurable) cookie-precedence rule allows the bearer
// path to run.
func (t *TraefikOidc) handleBearerRequest(rw http.ResponseWriter, req *http.Request) {
ip := clientIPForBearer(req)
if blocked, retryAfter := t.bearerFailureTracker.blocked(ip); blocked {
throttled := newBearerError(bearerErrThrottled, "ip in penalty box")
// Preserve the actual retry-after even if it diverged from the
// configured default (clock-skew, partial-window expiry).
if retryAfter > 0 {
rw.Header().Set("Retry-After", fmt.Sprintf("%d", int(retryAfter.Seconds())))
}
t.writeBearerError(rw, req, throttled)
return
}
token, ok := detectBearerToken(req)
if !ok {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidRequest, "missing or empty bearer token"))
return
}
if len(token) > AccessTokenConfig.MaxLength {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidToken, "token exceeds max length"))
return
}
if strings.Count(token, ".") != 2 {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidToken, "token is not a 3-segment JWT"))
return
}
if bErr := parseBearerJOSEHeader(token); bErr != nil {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, bErr)
return
}
p, bErr := t.buildPrincipalFromBearerToken(token)
if bErr != nil {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, bErr)
return
}
t.bearerFailureTracker.recordSuccess(ip)
if t.logger != nil {
t.logger.Debugf("bearer auth success: identifier_hash=%s path=%s",
hashIdentifierForLog(p.Identifier), req.URL.Path)
}
t.forwardAuthorized(rw, req, p)
}
// buildPrincipalFromBearerToken runs the full bearer verification pipeline
// described in spec §7.3 and returns a principal ready for forwardAuthorized.
// Returns a typed *bearerError on failure so the caller can map to status.
func (t *TraefikOidc) buildPrincipalFromBearerToken(token string) (*principal, *bearerError) {
if err := t.verifyTokenWithOpts(token, verifyOpts{skipReplayMarking: true}); err != nil {
return nil, newBearerError(bearerErrInvalidToken, "token verification failed: "+err.Error())
}
parsed, err := parseJWT(token)
if err != nil {
return nil, newBearerError(bearerErrInvalidToken, "post-verify parseJWT failed: "+err.Error())
}
claims := parsed.Claims
// Token-type guard. Reuse the well-tested classifier which already
// checks nonce / typ=at+jwt / token_use / scope / aud-vs-clientID.
if t.detectTokenType(parsed, token) {
return nil, newBearerError(bearerErrInvalidToken, "ID tokens are not accepted on the bearer path")
}
// Belt-and-braces explicit rejection (cheap, catches edge cases not
// covered by detectTokenType's heuristic).
if nonce, ok := claims["nonce"].(string); ok && nonce != "" {
return nil, newBearerError(bearerErrInvalidToken, "nonce claim present (ID-token shape)")
}
if tu, ok := claims["token_use"].(string); ok && tu == "id" {
return nil, newBearerError(bearerErrInvalidToken, "token_use=id rejected")
}
if bErr := enforceMultiAudienceAzp(claims, t.clientID); bErr != nil {
return nil, bErr
}
if bErr := enforceIatAge(claims, t.maxTokenAge); bErr != nil {
return nil, bErr
}
if t.requireTokenIntrospection {
if bErr := t.introspectOnBearerPath(token); bErr != nil {
return nil, bErr
}
}
rawIdentifier, bErr := resolveBearerIdentifier(claims, t.bearerIdentifierClaim)
if bErr != nil {
return nil, bErr
}
identifier, bErr := sanitizeBearerIdentifier(rawIdentifier, t.maxIdentifierLength)
if bErr != nil {
return nil, bErr
}
subject, _ := claims["sub"].(string)
clientID, _ := claims["azp"].(string)
if clientID == "" {
clientID, _ = claims["client_id"].(string)
}
return &principal{
Source: sourceBearer,
Identifier: identifier,
Subject: subject,
ClientID: clientID,
Claims: claims,
AccessToken: token,
}, nil
}
// introspectOnBearerPath calls the existing RFC 7662 introspector when the
// operator demands real-time revocation. Distinguishes "token revoked" (401)
// from "endpoint unavailable" (503) so transient infra failures don't look
// like credential failures.
func (t *TraefikOidc) introspectOnBearerPath(token string) *bearerError {
resp, err := t.introspectToken(token)
if err != nil {
return newBearerError(bearerErrIntrospectionUnavailable, "introspection failed: "+err.Error())
}
if !resp.Active {
return newBearerError(bearerErrTokenInactive, "introspection reports token inactive")
}
return nil
}
+812
View File
@@ -0,0 +1,812 @@
package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"golang.org/x/time/rate"
)
// =============================================================================
// Helper builders
// =============================================================================
// makeBearerJWT constructs a JWT with explicit header + claims for tests.
// Signature is opaque (b64("signature")) — bearer tests don't exercise the
// real cryptographic verifier; verification is bypassed via tokenCache pre-
// seed so the bearer pipeline under test sees a "verified" token.
func makeBearerJWT(t *testing.T, header, claims map[string]interface{}) string {
t.Helper()
hb, err := json.Marshal(header)
if err != nil {
t.Fatalf("marshal header: %v", err)
}
cb, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
return fmt.Sprintf("%s.%s.%s",
base64.RawURLEncoding.EncodeToString(hb),
base64.RawURLEncoding.EncodeToString(cb),
base64.RawURLEncoding.EncodeToString([]byte("signature")),
)
}
// defaultBearerHeader produces the standard RS256+kid header used in tests.
func defaultBearerHeader() map[string]interface{} {
return map[string]interface{}{"alg": "RS256", "kid": "test-kid"}
}
// defaultBearerClaims produces a baseline access-token claim set. Tests
// shallow-clone and override fields as needed.
func defaultBearerClaims() map[string]interface{} {
return map[string]interface{}{
"iss": "https://issuer.example.com",
"aud": "https://api.example.com",
"sub": "service-account-1",
"scope": "api:read api:write",
"exp": float64(time.Now().Add(time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
}
}
// makeBearerOIDC constructs a TraefikOidc wired for bearer auth tests. The
// real verifyTokenWithOpts pipeline is short-circuited via tokenCache pre-
// seed: any token Set into t.tokenCache returns nil from VerifyToken,
// letting tests exercise the post-verify bearer logic (classifier, identifier,
// throttle, header forwarding) without standing up JWKs.
func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
t.Helper()
sm := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("error"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://issuer.example.com",
audience: "https://api.example.com",
clientID: "https://api.example.com",
tokenCache: NewTokenCache(),
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
allowedRolesAndGroups: map[string]struct{}{},
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
ctx: context.Background(),
enableBearerAuth: true,
stripAuthorizationHeader: true,
bearerEmitWWWAuthenticate: true,
bearerOverridesCookie: false,
bearerIdentifierClaim: "sub",
maxIdentifierLength: 256,
maxTokenAge: 24 * time.Hour,
bearerFailureThreshold: 20,
bearerFailureWindow: 60 * time.Second,
bearerFailurePenalty: 60 * time.Second,
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
}
oidc.extractClaimsFunc = extractClaims
close(oidc.initComplete)
return oidc
}
// seedVerified pre-populates the tokenCache so verifyTokenWithOpts short-
// circuits to nil for the given token. Mirrors the production fast-return
// path at token_manager.go for previously-verified tokens.
func seedVerified(t *testing.T, oidc *TraefikOidc, token string, claims map[string]interface{}) {
t.Helper()
if oidc.tokenCache == nil {
oidc.tokenCache = NewTokenCache()
}
oidc.tokenCache.Set(token, claims, time.Hour)
}
// =============================================================================
// Unit tests — small helpers
// =============================================================================
func TestDetectBearerToken(t *testing.T) {
t.Parallel()
cases := []struct {
name string
header string
want string
ok bool
}{
{"missing header", "", "", false},
{"basic auth", "Basic abc", "", false},
{"bearer with token", "Bearer abc.def.ghi", "abc.def.ghi", true},
{"lowercase bearer", "bearer abc.def.ghi", "abc.def.ghi", true},
{"mixed case", "BeArEr abc.def.ghi", "abc.def.ghi", true},
{"empty token after prefix", "Bearer ", "", false},
{"bearer no space", "Bearerabc", "", false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tc.header != "" {
req.Header.Set("Authorization", tc.header)
}
got, ok := detectBearerToken(req)
if ok != tc.ok || got != tc.want {
t.Fatalf("got=(%q, %v), want=(%q, %v)", got, ok, tc.want, tc.ok)
}
})
}
}
func TestParseBearerJOSEHeader(t *testing.T) {
t.Parallel()
mk := func(t *testing.T, h map[string]interface{}) string {
return makeBearerJWT(t, h, map[string]interface{}{"sub": "x"})
}
cases := []struct {
header map[string]interface{}
name string
wantErr bool
}{
{name: "valid RS256", header: map[string]interface{}{"alg": "RS256", "kid": "k1"}, wantErr: false},
{name: "valid ES512", header: map[string]interface{}{"alg": "ES512", "kid": "abc-_.="}, wantErr: false},
{name: "alg=none rejected", header: map[string]interface{}{"alg": "none", "kid": "k1"}, wantErr: true},
{name: "alg=HS256 rejected", header: map[string]interface{}{"alg": "HS256", "kid": "k1"}, wantErr: true},
{name: "missing kid", header: map[string]interface{}{"alg": "RS256"}, wantErr: true},
{name: "kid too long", header: map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}, wantErr: true},
{name: "kid bad chars", header: map[string]interface{}{"alg": "RS256", "kid": "evil/../etc/passwd"}, wantErr: true},
{name: "kid with space", header: map[string]interface{}{"alg": "RS256", "kid": "key one"}, wantErr: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token := mk(t, tc.header)
err := parseBearerJOSEHeader(token)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
})
}
}
func TestSanitiseBearerIdentifier(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in string
want string
wantErr bool
}{
{"normal sub", "service-account-1", "service-account-1", false},
{"email-like", "alice@example.com", "alice@example.com", false},
{"trim whitespace", " abc ", "abc", false},
{"empty", "", "", true},
{"only whitespace", " ", "", true},
{"control char (newline)", "alice\nbob", "", true},
{"control char (CR)", "alice\rbob", "", true},
{"control char (NUL)", "alice\x00bob", "", true},
{"bidi override", "alice\u202ebob", "", true},
{"bidi isolate", "alice\u2066bob", "", true},
{"comma delimiter", "alice,bob", "", true},
{"semicolon delimiter", "alice;bob", "", true},
{"equals delimiter", "alice=bob", "", true},
{"over length", strings.Repeat("a", 257), "", true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := sanitizeBearerIdentifier(tc.in, 256)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
if !tc.wantErr && got != tc.want {
t.Fatalf("got=%q want=%q", got, tc.want)
}
})
}
}
func TestResolveBearerIdentifier(t *testing.T) {
t.Parallel()
cases := []struct {
claims map[string]interface{}
name string
claim string
want string
wantErr bool
}{
{name: "default sub", claims: map[string]interface{}{"sub": "abc"}, claim: "", want: "abc"},
{name: "explicit sub", claims: map[string]interface{}{"sub": "abc"}, claim: "sub", want: "abc"},
{name: "custom client_id claim", claims: map[string]interface{}{"client_id": "svc"}, claim: "client_id", want: "svc"},
{name: "missing claim", claims: map[string]interface{}{"other": "x"}, claim: "sub", wantErr: true},
{name: "non-string claim", claims: map[string]interface{}{"sub": 123}, claim: "sub", wantErr: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := resolveBearerIdentifier(tc.claims, tc.claim)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
if !tc.wantErr && got != tc.want {
t.Fatalf("got=%q want=%q", got, tc.want)
}
})
}
}
func TestEnforceMultiAudienceAzp(t *testing.T) {
t.Parallel()
const cid = "https://api.example.com"
cases := []struct {
claims map[string]interface{}
name string
wantErr bool
}{
{name: "single string aud", claims: map[string]interface{}{"aud": "x"}, wantErr: false},
{name: "single element array", claims: map[string]interface{}{"aud": []interface{}{"x"}}, wantErr: false},
{name: "multi-aud with matching azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": cid}, wantErr: false},
{name: "multi-aud missing azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}}, wantErr: true},
{name: "multi-aud empty azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": ""}, wantErr: true},
{name: "multi-aud wrong azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": "other"}, wantErr: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := enforceMultiAudienceAzp(tc.claims, cid)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
})
}
}
func TestEnforceIatAge(t *testing.T) {
t.Parallel()
now := time.Now()
cases := []struct {
name string
iat float64
maxAge time.Duration
wantErr bool
}{
{name: "fresh", iat: float64(now.Unix()), maxAge: time.Hour, wantErr: false},
{name: "23h59m old, max 24h", iat: float64(now.Add(-23*time.Hour - 59*time.Minute).Unix()), maxAge: 24 * time.Hour, wantErr: false},
{name: "25h old, max 24h", iat: float64(now.Add(-25 * time.Hour).Unix()), maxAge: 24 * time.Hour, wantErr: true},
{name: "1970 token", iat: float64(0), maxAge: 24 * time.Hour, wantErr: true},
{name: "maxAge disabled (0)", iat: float64(0), maxAge: 0, wantErr: false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := enforceIatAge(map[string]interface{}{"iat": tc.iat}, tc.maxAge)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
})
}
}
func TestBearerFailureTracker(t *testing.T) {
t.Parallel()
tr := newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
const ip = "10.0.0.1"
// Below threshold: not blocked.
for i := 0; i < 2; i++ {
tr.recordFailure(ip)
if b, _ := tr.blocked(ip); b {
t.Fatalf("blocked too early after %d failures", i+1)
}
}
// Threshold reached: blocked.
tr.recordFailure(ip)
if b, retry := tr.blocked(ip); !b || retry <= 0 {
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
}
// Success clears the counter.
tr.recordSuccess(ip)
if b, _ := tr.blocked(ip); b {
t.Fatalf("expected unblocked after success")
}
// Other IPs are unaffected.
if b, _ := tr.blocked("10.0.0.2"); b {
t.Fatalf("unrelated IP should not be blocked")
}
}
// =============================================================================
// Integration tests — full ServeHTTP via the bearer pipeline
// =============================================================================
func TestServeHTTP_Bearer_HappyPath(t *testing.T) {
t.Parallel()
var nextCalled atomic.Bool
var capturedHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled.Load() {
t.Fatalf("expected next handler to run; got status=%d body=%q", rw.Code, rw.Body.String())
}
if rw.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", rw.Code)
}
if got := capturedHeaders.Get("X-Forwarded-User"); got != "service-account-1" {
t.Fatalf("X-Forwarded-User=%q, want service-account-1", got)
}
if got := capturedHeaders.Get("Authorization"); got != "" {
t.Fatalf("Authorization should be stripped, got=%q", got)
}
}
func TestServeHTTP_Bearer_StripAuthDisabled(t *testing.T) {
t.Parallel()
var capturedAuth string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.stripAuthorizationHeader = false
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !strings.HasPrefix(capturedAuth, "Bearer ") {
t.Fatalf("expected Authorization to be forwarded, got=%q", capturedAuth)
}
}
func TestServeHTTP_Bearer_RejectIDToken(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for ID token rejection")
})
oidc := makeBearerOIDC(t, next)
// ID-token shape: nonce claim present and no scope. detectTokenType
// returns true.
claims := map[string]interface{}{
"iss": "https://issuer.example.com",
"aud": "https://api.example.com",
"sub": "user-1",
"nonce": "n-0S6_WzA2Mj",
"exp": float64(time.Now().Add(time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
}
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
if wa := rw.Header().Get("WWW-Authenticate"); !strings.Contains(wa, `error="invalid_token"`) {
t.Fatalf("expected WWW-Authenticate invalid_token, got=%q", wa)
}
}
func TestServeHTTP_Bearer_AlgNoneRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for alg=none")
})
oidc := makeBearerOIDC(t, next)
header := map[string]interface{}{"alg": "none", "kid": "k1"}
claims := defaultBearerClaims()
token := makeBearerJWT(t, header, claims)
// Even if we pre-seeded the cache, the early alg pin runs FIRST.
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_KidTooLongRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for oversized kid")
})
oidc := makeBearerOIDC(t, next)
header := map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}
claims := defaultBearerClaims()
token := makeBearerJWT(t, header, claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_MultiAudRequiresAzp(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for multi-aud without azp")
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
delete(claims, "azp")
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_MultiAudWithAzpAccepted(t *testing.T) {
t.Parallel()
var nextCalled atomic.Bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
claims["azp"] = oidc.clientID
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusOK || !nextCalled.Load() {
t.Fatalf("expected 200 + next called; got status=%d called=%v", rw.Code, nextCalled.Load())
}
}
func TestServeHTTP_Bearer_IatTooOldRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for old iat")
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["iat"] = float64(time.Now().Add(-25 * time.Hour).Unix())
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_IdentifierWithBidiRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for bidi identifier")
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["sub"] = "alice\u202ebob"
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_ReplayRegression(t *testing.T) {
t.Parallel()
var successCount atomic.Int32
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
successCount.Add(1)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["jti"] = "regression-jti"
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
for i := 0; i < 100; i++ {
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusOK {
t.Fatalf("iteration %d: status=%d, want 200", i, rw.Code)
}
}
if successCount.Load() != 100 {
t.Fatalf("successCount=%d, want 100", successCount.Load())
}
}
func TestServeHTTP_Bearer_ThrottleTrips429(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run during throttle test")
})
oidc := makeBearerOIDC(t, next)
oidc.bearerFailureTracker = newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
// Send malformed bearers from the same RemoteAddr until threshold trips.
send := func() *httptest.ResponseRecorder {
req := httptest.NewRequest("GET", "/api/work", nil)
req.RemoteAddr = "10.0.0.5:1234"
req.Header.Set("Authorization", "Bearer not-a-jwt")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
return rw
}
for i := 0; i < 3; i++ {
rw := send()
if rw.Code != http.StatusUnauthorized {
t.Fatalf("pre-throttle iteration %d: status=%d, want 401", i, rw.Code)
}
}
// 4th request: throttled.
rw := send()
if rw.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429 after threshold, got %d", rw.Code)
}
if ra := rw.Header().Get("Retry-After"); ra == "" {
t.Fatalf("expected Retry-After header on 429")
}
}
func TestServeHTTP_Bearer_ExcludedURLStripsAuth(t *testing.T) {
t.Parallel()
var capturedAuth string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.excludedURLs = map[string]struct{}{"/favicon.ico": {}}
req := httptest.NewRequest("GET", "/favicon.ico", nil)
req.Header.Set("Authorization", "Bearer abc.def.ghi")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusOK {
t.Fatalf("excluded path should pass; got %d", rw.Code)
}
if capturedAuth != "" {
t.Fatalf("Authorization must be stripped on excluded paths, got=%q", capturedAuth)
}
}
func TestServeHTTP_Bearer_RolesGate(t *testing.T) {
t.Parallel()
cases := []struct {
name string
rolesClaim []interface{}
want int
}{
{name: "matching role", rolesClaim: []interface{}{"admin"}, want: http.StatusOK},
{name: "no matching role", rolesClaim: []interface{}{"viewer"}, want: http.StatusForbidden},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.allowedRolesAndGroups = map[string]struct{}{"admin": {}}
oidc.roleClaimName = "roles"
claims := defaultBearerClaims()
claims["roles"] = tc.rolesClaim
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != tc.want {
t.Fatalf("status=%d, want %d", rw.Code, tc.want)
}
})
}
}
func TestServeHTTP_Bearer_CookieWinsByDefault(t *testing.T) {
t.Parallel()
// Both cookie and bearer present: cookie path runs (which will redirect
// to /authorize since the cookie is empty/unauthenticated).
var nextCalled atomic.Bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
prefix := oidc.sessionManager.GetCookiePrefix()
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Cookie path consumed the request; bearer was ignored. Since the
// cookie is empty, the cookie path will either 302 to /authorize or
// return 401 — in either case, next must NOT be called.
if nextCalled.Load() {
t.Fatalf("next must not be called when bearer is ignored due to cookie precedence")
}
}
func TestServeHTTP_Bearer_BearerOverridesCookie(t *testing.T) {
t.Parallel()
var nextCalled atomic.Bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.bearerOverridesCookie = true
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
prefix := oidc.sessionManager.GetCookiePrefix()
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled.Load() || rw.Code != http.StatusOK {
t.Fatalf("expected bearer to win with override; status=%d called=%v", rw.Code, nextCalled.Load())
}
}
func TestServeHTTP_Bearer_OversizedToken(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for oversized token")
})
oidc := makeBearerOIDC(t, next)
huge := strings.Repeat("a", AccessTokenConfig.MaxLength+1)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+huge)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_MalformedJWT(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for malformed JWT")
})
oidc := makeBearerOIDC(t, next)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer not.jwt") // 1 dot
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_FeatureOffPassesThrough(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Should not be reached: cookie path runs and (with no session)
// will redirect or 401. We assert no panic / next not called.
t.Fatalf("next must not run when bearer is off and no valid session exists")
})
oidc := makeBearerOIDC(t, next)
oidc.enableBearerAuth = false
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Expect non-200: either 302 to /authorize or 401. The point is the
// bearer pipeline didn't run.
if rw.Code == http.StatusOK {
t.Fatalf("expected non-200 when bearer is off; got %d", rw.Code)
}
}
// =============================================================================
// Startup validation tests
// =============================================================================
func TestStartupValidation_BearerRequiresAudience(t *testing.T) {
t.Parallel()
cfg := CreateConfig()
cfg.ProviderURL = "https://issuer.example.com"
cfg.ClientID = "id"
cfg.ClientSecret = "secret"
cfg.CallbackURL = "/oauth/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
cfg.EnableBearerAuth = true
cfg.Audience = ""
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
if err == nil || !strings.Contains(err.Error(), "requires Audience") {
t.Fatalf("expected audience-required error, got %v", err)
}
}
func TestStartupValidation_BearerRejectsEmailIdentifier(t *testing.T) {
t.Parallel()
cfg := CreateConfig()
cfg.ProviderURL = "https://issuer.example.com"
cfg.ClientID = "id"
cfg.ClientSecret = "secret"
cfg.CallbackURL = "/oauth/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
cfg.EnableBearerAuth = true
cfg.Audience = "https://api.example.com"
cfg.BearerIdentifierClaim = "email"
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
if err == nil || !strings.Contains(err.Error(), "bearerIdentifierClaim=\"email\"") {
t.Fatalf("expected email-identifier rejection, got %v", err)
}
}
// =============================================================================
// Principal invariants
// =============================================================================
func TestBuildPrincipalFromSession_NoIdentifier(t *testing.T) {
t.Parallel()
oidc := &TraefikOidc{logger: NewLogger("error")}
if p := oidc.buildPrincipalFromSession(nil); p != nil {
t.Fatalf("nil session must produce nil principal")
}
}
+137
View File
@@ -0,0 +1,137 @@
package traefikoidc
import (
"encoding/pem"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
// testCertPEM returns a valid PEM-encoded certificate harvested from an
// httptest.NewTLSServer. Using httptest keeps the test free of any
// handwritten static cert that could expire.
func testCertPEM(t *testing.T) string {
t.Helper()
srv := httptest.NewTLSServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
t.Cleanup(srv.Close)
cert := srv.Certificate()
if cert == nil {
t.Fatal("httptest.NewTLSServer did not expose a certificate")
}
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
}
func TestLoadCACertPool_Empty(t *testing.T) {
cfg := &Config{}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool != nil {
t.Errorf("expected nil pool when no CA source configured, got %v", pool)
}
}
func TestLoadCACertPool_InlinePEM(t *testing.T) {
cfg := &Config{CACertPEM: testCertPEM(t)}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool for valid CACertPEM")
}
}
func TestLoadCACertPool_InlinePEM_Garbage(t *testing.T) {
cfg := &Config{CACertPEM: "not a pem"}
pool, err := cfg.loadCACertPool()
if err == nil {
t.Fatal("expected error for garbage CACertPEM, got nil")
}
if pool != nil {
t.Errorf("expected nil pool on error, got %v", pool)
}
if !strings.Contains(err.Error(), "caCertPEM") {
t.Errorf("error should name the failing field, got: %v", err)
}
}
func TestLoadCACertPool_FilePath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "ca.pem")
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
t.Fatalf("writing temp PEM: %v", err)
}
cfg := &Config{CACertPath: path}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool for valid CACertPath")
}
}
func TestLoadCACertPool_FilePath_Missing(t *testing.T) {
cfg := &Config{CACertPath: "/does/not/exist/ca.pem"}
pool, err := cfg.loadCACertPool()
if err == nil {
t.Fatal("expected error for missing CACertPath, got nil")
}
if pool != nil {
t.Errorf("expected nil pool on error, got %v", pool)
}
}
func TestLoadCACertPool_Combined(t *testing.T) {
// Both inline and file sources populated — certificates from both should
// be accepted into the same pool.
dir := t.TempDir()
path := filepath.Join(dir, "ca.pem")
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
t.Fatalf("writing temp PEM: %v", err)
}
cfg := &Config{CACertPath: path, CACertPEM: testCertPEM(t)}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool when both sources set")
}
}
func TestSharedTransportPool_ConfigKeyDistinguishesCAAndSkipVerify(t *testing.T) {
p := GetGlobalTransportPool()
cfgSystem := DefaultHTTPClientConfig()
cfgSkip := DefaultHTTPClientConfig()
cfgSkip.InsecureSkipVerify = true
cfgCustomCA := DefaultHTTPClientConfig()
pool, err := (&Config{CACertPEM: testCertPEM(t)}).loadCACertPool()
if err != nil {
t.Fatalf("loadCACertPool: %v", err)
}
cfgCustomCA.RootCAs = pool
keys := map[string]string{
"system": p.configKey(cfgSystem),
"skip": p.configKey(cfgSkip),
"customCA": p.configKey(cfgCustomCA),
}
seen := make(map[string]string, len(keys))
for name, key := range keys {
if dup, ok := seen[key]; ok {
t.Errorf("configKey collision: %s and %s share key %q", name, dup, key)
}
seen[key] = name
}
}
+241
View File
@@ -0,0 +1,241 @@
package traefikoidc
import (
"fmt"
"sync"
"testing"
"time"
)
// =============================================================================
// UNIVERSAL CACHE BENCHMARKS
// =============================================================================
func BenchmarkCacheSet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
i++
}
})
}
func BenchmarkCacheGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
for i := 0; i < 1000; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
cache.Get(fmt.Sprintf("key%d", i%1000))
i++
}
})
}
func BenchmarkCacheSetGet(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key%d", i)
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
cache.Get(key)
i++
}
})
}
func BenchmarkCacheLRUEviction(b *testing.B) {
config := createTestCacheConfig()
config.MaxSize = 100
cache := NewUniversalCache(config)
defer cache.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
}
}
func BenchmarkCacheConcurrent(b *testing.B) {
cache := NewUniversalCache(createTestCacheConfig())
defer cache.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
switch i % 3 {
case 0:
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
case 1:
cache.Get(fmt.Sprintf("key%d", i))
case 2:
cache.Delete(fmt.Sprintf("key%d", i))
}
i++
}
})
}
// =============================================================================
// CACHE MANAGER BENCHMARKS
// =============================================================================
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set("benchmark-key", "benchmark-value", time.Hour)
}
}
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
cache.Set("benchmark-key", "benchmark-value", time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("benchmark-key")
}
}
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
t := &testing.T{}
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
key := fmt.Sprintf("benchmark-key-%d", i)
cache.Set(key, "value", time.Hour)
b.StartTimer()
cache.Delete(key)
}
}
// =============================================================================
// CACHE COMPATIBILITY BENCHMARKS
// =============================================================================
func BenchmarkNewBoundedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewBoundedCache(1000)
}
}
func BenchmarkNewOptimizedCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewOptimizedCache()
}
}
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
strategy := NewLRUStrategy(1000)
item := "test-item"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strategy.EstimateSize(item)
}
}
// =============================================================================
// SHARDED CACHE BENCHMARKS
// =============================================================================
func BenchmarkShardedCache(b *testing.B) {
b.Run("Set", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
})
b.Run("Get", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
for i := 0; i < 10000; i++ {
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get(fmt.Sprintf("key-%d", i%10000))
}
})
b.Run("ParallelSetGet", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, i, 5*time.Minute)
cache.Get(key)
i++
}
})
})
}
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
b.Run("ShardedCache64", func(b *testing.B) {
cache := NewShardedCache(64, 100000)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
if !cache.Exists(key) {
cache.Set(key, true, 5*time.Minute)
}
i++
}
})
})
b.Run("GlobalMutexCache", func(b *testing.B) {
var mu sync.RWMutex
data := make(map[string]bool)
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("jti-%d", i%10000)
mu.RLock()
_, exists := data[key]
mu.RUnlock()
if !exists {
mu.Lock()
data[key] = true
mu.Unlock()
}
i++
}
})
})
}
+3 -3
View File
@@ -155,9 +155,9 @@ type CacheStrategy interface {
// CacheEntry for backward compatibility
type CacheEntry struct {
Key string
Value interface{}
ExpiresAt time.Time
Value interface{}
Key string
}
// Cache is an alias for backward compatibility
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
// ListNode for backward compatibility
type ListNode struct {
Key string
Value interface{}
Next *ListNode
Prev *ListNode
Key string
}
// NewFixedMetadataCache creates a metadata cache with fixed configuration
File diff suppressed because it is too large Load Diff
+75 -8
View File
@@ -20,11 +20,39 @@ var (
cacheManagerInitOnce sync.Once
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// GetGlobalCacheManager returns a singleton CacheManager instance.
//
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
cacheManagerInitOnce.Do(func() {
var redisConfig *RedisConfig
var logger *Logger
if config != nil {
logger = NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
redisConfig = config.Redis
}
globalCacheManagerInstance = &CacheManager{
manager: GetUniversalCacheManager(nil),
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
return globalCacheManagerInstance
@@ -34,7 +62,7 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
}
// GetSharedTokenCache returns the shared token cache
@@ -61,6 +89,38 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
return &JWKCache{cache: cm.manager.GetJWKCache()}
}
// GetSharedIntrospectionCache returns the shared token introspection cache
// for caching OAuth 2.0 Token Introspection (RFC 7662) results
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
}
// GetSharedTokenTypeCache returns the shared token type cache
// for caching token type detection results to improve performance
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
// for backchannel and front-channel logout (IdP-initiated logout)
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
}
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
// by the refresh path to coalesce grants across Traefik replicas via Redis.
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
}
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
@@ -78,12 +138,13 @@ func CleanupGlobalCacheManager() error {
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
type CacheInterfaceWrapper struct {
cache *UniversalCache
cache *UniversalCache
managed bool // If true, cache is managed globally and Close() is a no-op
}
// Set stores a value
func (c *CacheInterfaceWrapper) Set(key string, value interface{}, ttl time.Duration) {
c.cache.Set(key, value, ttl)
_ = c.cache.Set(key, value, ttl) // Safe to ignore: cache set failures are non-critical
}
// Get retrieves a value
@@ -106,11 +167,17 @@ func (c *CacheInterfaceWrapper) Cleanup() {
c.cache.Cleanup()
}
// Close shuts down the cache
// Close shuts down the cache if it's not managed globally.
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
// when multiple plugin instances are closed during Traefik configuration reloads.
func (c *CacheInterfaceWrapper) Close() {
// Close the underlying cache to stop goroutines
if c.managed {
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
return
}
// Standalone cache - close it properly to stop cleanup goroutines
if c.cache != nil {
c.cache.Close()
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
}
}
+1854
View File
File diff suppressed because it is too large Load Diff
-319
View File
@@ -1,319 +0,0 @@
// Package circuit_breaker provides circuit breaker implementation for resilience
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
// CircuitBreakerState represents the current state of a circuit breaker.
// The circuit breaker pattern prevents cascading failures by monitoring
// error rates and temporarily blocking requests to failing services.
type CircuitBreakerState int
// Circuit breaker states following the standard pattern:
// Closed: Normal operation, requests flow through
// Open: Circuit is tripped, requests are blocked
// HalfOpen: Testing state, limited requests allowed to test recovery
const (
// CircuitBreakerClosed allows all requests through (normal operation)
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen blocks all requests (service is failing)
CircuitBreakerOpen
// CircuitBreakerHalfOpen allows limited requests to test service recovery
CircuitBreakerHalfOpen
)
// String returns a string representation of the circuit breaker state
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// Logger interface for dependency injection
type Logger interface {
Infof(format string, args ...interface{})
Errorf(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// BaseRecoveryMechanism interface for common functionality
type BaseRecoveryMechanism interface {
RecordRequest()
RecordSuccess()
RecordFailure()
GetBaseMetrics() map[string]interface{}
LogInfo(format string, args ...interface{})
LogError(format string, args ...interface{})
LogDebug(format string, args ...interface{})
}
// CircuitBreaker implements the circuit breaker pattern for external service calls.
// It monitors failure rates and automatically opens the circuit when failures
// exceed the threshold, preventing further requests until the service recovers.
type CircuitBreaker struct {
// baseRecovery provides common functionality
baseRecovery BaseRecoveryMechanism
// maxFailures is the threshold for opening the circuit
maxFailures int
// timeout is how long to wait before allowing requests in half-open state
timeout time.Duration
// resetTimeout is how long to wait before transitioning from open to half-open
resetTimeout time.Duration
// state tracks the current circuit breaker state
state CircuitBreakerState
// failures counts consecutive failures
failures int64
// lastFailureTime records when the last failure occurred
lastFailureTime time.Time
// mutex protects shared state
mutex sync.RWMutex
// logger for debugging and monitoring
logger Logger
}
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
// These settings control when the circuit opens and how it recovers.
type CircuitBreakerConfig struct {
// MaxFailures is the number of failures before opening the circuit
MaxFailures int `json:"max_failures"`
// Timeout is how long to wait before trying to recover (open -> half-open)
Timeout time.Duration `json:"timeout"`
// ResetTimeout is how long to wait before fully closing the circuit
ResetTimeout time.Duration `json:"reset_timeout"`
}
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
// Configured for typical web service scenarios with moderate tolerance for failures.
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 60 * time.Second,
ResetTimeout: 30 * time.Second,
}
}
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
// The circuit breaker starts in the closed state, allowing all requests through.
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
return &CircuitBreaker{
baseRecovery: baseRecovery,
maxFailures: config.MaxFailures,
timeout: config.Timeout,
resetTimeout: config.ResetTimeout,
state: CircuitBreakerClosed,
logger: logger,
}
}
// ExecuteWithContext executes a function through the circuit breaker with context.
// It checks if requests are allowed, executes the function, and updates the circuit state
// based on the result. Implements the ErrorRecoveryMechanism interface.
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
if cb.baseRecovery != nil {
cb.baseRecovery.RecordRequest()
}
if !cb.allowRequest() {
return fmt.Errorf("circuit breaker is open")
}
err := fn()
if err != nil {
cb.recordFailure()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordFailure()
}
return err
}
cb.recordSuccess()
if cb.baseRecovery != nil {
cb.baseRecovery.RecordSuccess()
}
return nil
}
// Execute executes a function through the circuit breaker without context.
// This is provided for backward compatibility with existing code.
func (cb *CircuitBreaker) Execute(fn func() error) error {
return cb.ExecuteWithContext(context.Background(), fn)
}
// allowRequest determines whether to allow a request based on the circuit state.
// Handles state transitions from open to half-open based on timeout.
func (cb *CircuitBreaker) allowRequest() bool {
cb.mutex.Lock()
defer cb.mutex.Unlock()
now := time.Now()
switch cb.state {
case CircuitBreakerClosed:
return true
case CircuitBreakerOpen:
if now.Sub(cb.lastFailureTime) > cb.timeout {
cb.state = CircuitBreakerHalfOpen
if cb.logger != nil {
cb.logger.Infof("Circuit breaker transitioning to half-open state")
}
return true
}
return false
case CircuitBreakerHalfOpen:
return true
default:
return false
}
}
// recordFailure records a failure and potentially opens the circuit.
// Updates failure count and triggers state transitions when thresholds are exceeded.
func (cb *CircuitBreaker) recordFailure() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.failures++
cb.lastFailureTime = time.Now()
switch cb.state {
case CircuitBreakerClosed:
if cb.failures >= int64(cb.maxFailures) {
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
}
}
case CircuitBreakerHalfOpen:
cb.state = CircuitBreakerOpen
if cb.baseRecovery != nil {
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
}
}
}
// recordSuccess records a successful request and potentially closes the circuit.
// Resets failure count and transitions from half-open to closed state on success.
func (cb *CircuitBreaker) recordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
switch cb.state {
case CircuitBreakerHalfOpen:
cb.failures = 0
cb.state = CircuitBreakerClosed
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
}
case CircuitBreakerClosed:
cb.failures = 0
}
}
// GetState returns the current state of the circuit breaker.
// Thread-safe method for monitoring circuit breaker status.
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state
}
// Reset resets the circuit breaker to its initial closed state.
// Clears failure count and state, effectively recovering from any open state.
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.state = CircuitBreakerClosed
atomic.StoreInt64(&cb.failures, 0)
if cb.baseRecovery != nil {
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
}
}
// IsAvailable returns whether the circuit breaker is currently allowing requests.
// This provides a quick way to check if the service is available.
func (cb *CircuitBreaker) IsAvailable() bool {
return cb.allowRequest()
}
// GetMetrics returns comprehensive metrics about the circuit breaker.
// Includes state information, failure counts, configuration, and base metrics.
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
cb.mutex.RLock()
state := cb.state
failures := cb.failures
lastFailureTime := cb.lastFailureTime
cb.mutex.RUnlock()
var metrics map[string]interface{}
if cb.baseRecovery != nil {
metrics = cb.baseRecovery.GetBaseMetrics()
} else {
metrics = make(map[string]interface{})
}
metrics["state"] = state.String()
metrics["current_failures"] = failures
metrics["max_failures"] = cb.maxFailures
metrics["timeout"] = cb.timeout.String()
metrics["reset_timeout"] = cb.resetTimeout.String()
if !lastFailureTime.IsZero() {
metrics["last_failure_time"] = lastFailureTime
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
}
return metrics
}
// GetFailureCount returns the current failure count
func (cb *CircuitBreaker) GetFailureCount() int64 {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.failures
}
// GetLastFailureTime returns the time of the last failure
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.lastFailureTime
}
// IsOpen returns true if the circuit breaker is in open state
func (cb *CircuitBreaker) IsOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerOpen
}
// IsClosed returns true if the circuit breaker is in closed state
func (cb *CircuitBreaker) IsClosed() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerClosed
}
// IsHalfOpen returns true if the circuit breaker is in half-open state
func (cb *CircuitBreaker) IsHalfOpen() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state == CircuitBreakerHalfOpen
}
-981
View File
@@ -1,981 +0,0 @@
package circuit_breaker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// Mock implementations for testing
type mockLogger struct {
infoLogs []string
errorLogs []string
debugLogs []string
mu sync.RWMutex
}
func (m *mockLogger) Infof(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Errorf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) Debugf(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockLogger) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
//lint:ignore U1000 May be needed for future error log verification tests
func (m *mockLogger) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
//lint:ignore U1000 May be needed for future test isolation
func (m *mockLogger) reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = nil
m.errorLogs = nil
m.debugLogs = nil
}
type mockBaseRecoveryMechanism struct {
requestCount int64
successCount int64
failureCount int64
infoLogs []string
errorLogs []string
debugLogs []string
baseMetrics map[string]interface{}
mu sync.RWMutex
}
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
return &mockBaseRecoveryMechanism{
baseMetrics: make(map[string]interface{}),
}
}
func (m *mockBaseRecoveryMechanism) RecordRequest() {
atomic.AddInt64(&m.requestCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
atomic.AddInt64(&m.successCount, 1)
}
func (m *mockBaseRecoveryMechanism) RecordFailure() {
atomic.AddInt64(&m.failureCount, 1)
}
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]interface{})
for k, v := range m.baseMetrics {
result[k] = v
}
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
result["total_successes"] = atomic.LoadInt64(&m.successCount)
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
return result
}
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
}
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
return atomic.LoadInt64(&m.requestCount)
}
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
return atomic.LoadInt64(&m.successCount)
}
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
return atomic.LoadInt64(&m.failureCount)
}
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.infoLogs))
copy(result, m.infoLogs)
return result
}
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]string, len(m.errorLogs))
copy(result, m.errorLogs)
return result
}
func TestCircuitBreakerState_String(t *testing.T) {
tests := []struct {
state CircuitBreakerState
expected string
}{
{CircuitBreakerClosed, "closed"},
{CircuitBreakerOpen, "open"},
{CircuitBreakerHalfOpen, "half-open"},
{CircuitBreakerState(999), "unknown"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := tt.state.String()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
func TestDefaultCircuitBreakerConfig(t *testing.T) {
config := DefaultCircuitBreakerConfig()
if config.MaxFailures != 2 {
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
}
if config.Timeout != 60*time.Second {
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
}
if config.ResetTimeout != 30*time.Second {
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
}
}
func TestNewCircuitBreaker(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
if cb == nil {
t.Fatal("NewCircuitBreaker returned nil")
}
if cb.maxFailures != 3 {
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
}
if cb.timeout != 30*time.Second {
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
}
if cb.resetTimeout != 15*time.Second {
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
}
if cb.state != CircuitBreakerClosed {
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
}
if cb.logger != logger {
t.Error("Expected logger to be set")
}
if cb.baseRecovery != baseRecovery {
t.Error("Expected baseRecovery to be set")
}
}
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getSuccessCount() != 1 {
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
}
}
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
}
if baseRecovery.getRequestCount() != 1 {
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
}
if baseRecovery.getFailureCount() != 1 {
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
}
}
func TestCircuitBreaker_Execute(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.Execute(testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
}
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
// First failure
err := cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on first failure, got %v", err)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
}
// Second failure - should open circuit
err = cb.ExecuteWithContext(ctx, testFunc)
if err != testError {
t.Errorf("Expected test error on second failure, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
}
// Third attempt - should be blocked
callCount := 0
blockedFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(ctx, blockedFunc)
if err == nil {
t.Error("Expected error when circuit is open")
}
if callCount != 0 {
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
}
}
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond, // Very short for testing
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next request should transition to half-open
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err = cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error in half-open state, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
}
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
}
}
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Wait for timeout to allow half-open transition
time.Sleep(15 * time.Millisecond)
// First call should transition to half-open, but we'll force it by checking allowRequest
if !cb.allowRequest() {
t.Error("Expected allowRequest to return true after timeout")
}
if cb.GetState() != CircuitBreakerHalfOpen {
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
}
// Failure in half-open should return to open
err := cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if err != testError {
t.Errorf("Expected test error, got %v", err)
}
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
}
}
func TestCircuitBreaker_Reset(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
_ = cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
if cb.GetState() != CircuitBreakerOpen {
t.Errorf("Expected state to be Open, got %v", cb.GetState())
}
// Reset circuit
cb.Reset()
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
}
if cb.GetFailureCount() != 0 {
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
}
// Should allow requests again
callCount := 0
err := cb.ExecuteWithContext(context.Background(), func() error {
callCount++
return nil
})
if err != nil {
t.Errorf("Expected no error after reset, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
}
}
func TestCircuitBreaker_IsAvailable(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially available
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should not be available when open
if cb.IsAvailable() {
t.Error("Expected circuit breaker to be unavailable when open")
}
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Should be available again after timeout (half-open)
if !cb.IsAvailable() {
t.Error("Expected circuit breaker to be available after timeout")
}
}
func TestCircuitBreaker_StateCheckers(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially closed
if !cb.IsClosed() {
t.Error("Expected circuit breaker to be closed initially")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open initially")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open initially")
}
// Trigger opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Should be open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when open")
}
if !cb.IsOpen() {
t.Error("Expected circuit breaker to be open")
}
if cb.IsHalfOpen() {
t.Error("Expected circuit breaker not to be half-open when open")
}
// Wait for timeout and trigger half-open
time.Sleep(15 * time.Millisecond)
cb.allowRequest() // This will transition to half-open
// Should be half-open
if cb.IsClosed() {
t.Error("Expected circuit breaker not to be closed when half-open")
}
if cb.IsOpen() {
t.Error("Expected circuit breaker not to be open when half-open")
}
if !cb.IsHalfOpen() {
t.Error("Expected circuit breaker to be half-open")
}
}
func TestCircuitBreaker_GetMetrics(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 30 * time.Second,
ResetTimeout: 15 * time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Record some activity
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
metrics := cb.GetMetrics()
// Check circuit breaker specific metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["current_failures"] != int64(1) {
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
if metrics["timeout"] != "30s" {
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
}
if metrics["reset_timeout"] != "15s" {
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
}
// Check base metrics are included
if metrics["total_requests"] != int64(1) {
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
}
if metrics["custom_metric"] != "custom_value" {
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
}
// Check failure time metrics
if _, exists := metrics["last_failure_time"]; !exists {
t.Error("Expected last_failure_time to exist")
}
if _, exists := metrics["time_since_last_failure"]; !exists {
t.Error("Expected time_since_last_failure to exist")
}
}
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
metrics := cb.GetMetrics()
// Should still have circuit breaker metrics
if metrics["state"] != "closed" {
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
}
if metrics["max_failures"] != 2 {
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
}
// Should not have base metrics
if _, exists := metrics["total_requests"]; exists {
t.Error("Expected total_requests not to exist without base recovery")
}
}
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Initially should be zero
if !cb.GetLastFailureTime().IsZero() {
t.Error("Expected last failure time to be zero initially")
}
// Record a failure
before := time.Now()
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
after := time.Now()
lastFailure := cb.GetLastFailureTime()
if lastFailure.IsZero() {
t.Error("Expected last failure time to be set after failure")
}
if lastFailure.Before(before) || lastFailure.After(after) {
t.Errorf("Expected last failure time to be between %v and %v, got %v",
before, after, lastFailure)
}
}
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
cb := NewCircuitBreaker(config, logger, nil)
callCount := 0
testFunc := func() error {
callCount++
return nil
}
err := cb.ExecuteWithContext(context.Background(), testFunc)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("Expected function to be called once, got %d", callCount)
}
// Should work fine without base recovery
if cb.GetState() != CircuitBreakerClosed {
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
}
}
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 10, // Higher threshold for concurrent test
Timeout: 100 * time.Millisecond,
ResetTimeout: 50 * time.Millisecond,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
const numGoroutines = 10
const numOperations = 50
var wg sync.WaitGroup
successCount := int64(0)
errorCount := int64(0)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
err := cb.ExecuteWithContext(context.Background(), func() error {
// Simulate some failures
if j%10 == 9 { // Every 10th operation fails
return fmt.Errorf("simulated error")
}
return nil
})
if err != nil {
atomic.AddInt64(&errorCount, 1)
} else {
atomic.AddInt64(&successCount, 1)
}
// Intermittently check state and metrics
if j%5 == 0 {
cb.GetState()
cb.GetMetrics()
cb.IsAvailable()
}
}
}(i)
}
wg.Wait()
// Verify we got both successes and errors
finalSuccessCount := atomic.LoadInt64(&successCount)
finalErrorCount := atomic.LoadInt64(&errorCount)
if finalSuccessCount == 0 {
t.Error("Expected some successful operations")
}
if finalErrorCount == 0 {
t.Error("Expected some failed operations")
}
totalOperations := finalSuccessCount + finalErrorCount
expectedMax := int64(numGoroutines * numOperations)
if totalOperations > expectedMax {
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
}
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
finalSuccessCount, finalErrorCount, cb.GetState())
}
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Trigger circuit opening
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Check that error was logged when circuit opened
errorLogs := baseRecovery.getErrorLogs()
if len(errorLogs) == 0 {
t.Error("Expected error log when circuit breaker opened")
} else {
if !contains(errorLogs, "Circuit breaker opened after") {
t.Errorf("Expected circuit opening log, got %v", errorLogs)
}
}
// Wait and trigger half-open
time.Sleep(15 * time.Millisecond)
// Successful request should close circuit and log
cb.ExecuteWithContext(context.Background(), func() error {
return nil
})
// Check that success was logged when circuit closed
infoLogs := baseRecovery.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log when circuit breaker closed")
} else {
if !contains(infoLogs, "Circuit breaker closed after successful request") {
t.Errorf("Expected circuit closing log, got %v", infoLogs)
}
}
// Reset should also be logged
cb.Reset()
infoLogs = baseRecovery.getInfoLogs()
if !contains(infoLogs, "Circuit breaker has been reset") {
t.Errorf("Expected reset log, got %v", infoLogs)
}
}
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
config := CircuitBreakerConfig{
MaxFailures: 1,
Timeout: 10 * time.Millisecond,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Wait for timeout and check half-open transition logging
testError := fmt.Errorf("test error")
cb.ExecuteWithContext(context.Background(), func() error {
return testError
})
// Wait for timeout
time.Sleep(15 * time.Millisecond)
// Next allowRequest call should log transition to half-open
cb.allowRequest()
infoLogs := logger.getInfoLogs()
if len(infoLogs) == 0 {
t.Error("Expected info log for half-open transition")
} else {
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
t.Errorf("Expected half-open transition log, got %v", infoLogs)
}
}
}
// Helper function to check if a slice contains a string with substring
func contains(slice []string, substr string) bool {
for _, s := range slice {
if len(s) >= len(substr) && s[:len(substr)] == substr {
return true
}
}
return false
}
// Benchmark tests
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testFunc := func() error {
return nil
}
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.ExecuteWithContext(ctx, testFunc)
}
})
}
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
config := CircuitBreakerConfig{
MaxFailures: 1000, // High threshold to avoid opening during benchmark
Timeout: time.Second,
ResetTimeout: time.Second,
}
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
testError := fmt.Errorf("test error")
testFunc := func() error {
return testError
}
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.ExecuteWithContext(ctx, testFunc)
}
}
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cb.GetState()
}
})
}
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
config := DefaultCircuitBreakerConfig()
logger := &mockLogger{}
baseRecovery := newMockBaseRecovery()
cb := NewCircuitBreaker(config, logger, baseRecovery)
// Add some activity
for i := 0; i < 100; i++ {
if i%2 == 0 {
cb.ExecuteWithContext(context.Background(), func() error { return nil })
} else {
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.GetMetrics()
}
}
+295
View File
@@ -0,0 +1,295 @@
package traefikoidc
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"os"
"time"
)
// isSupportedClientAssertionAlg reports whether alg is a recognized JWS
// algorithm for private_key_jwt (RFC 7523 §2.2).
func isSupportedClientAssertionAlg(alg string) bool {
switch alg {
case "RS256", "RS384", "RS512",
"PS256", "PS384", "PS512",
"ES256", "ES384", "ES512":
return true
}
return false
}
// ClientAssertionSigner builds and signs client_assertion JWTs (RFC 7523 §2.2).
type ClientAssertionSigner struct {
key crypto.PrivateKey
alg string
kid string
// rand is the entropy source for jti generation and PSS/ECDSA signing.
// Defaults to crypto/rand.Reader when nil.
rand io.Reader
// now returns the current time. Defaults to time.Now when nil.
now func() time.Time
}
// NewClientAssertionSigner parses pemBytes as a private key, validates that
// alg is consistent with the key type, and returns a ready-to-use signer.
// kid is placed verbatim in the JWS header.
//
// PEM block types understood:
// - "PRIVATE KEY" → PKCS#8 (tried first for all types)
// - "RSA PRIVATE KEY" → PKCS#1
// - "EC PRIVATE KEY" → SEC1
func NewClientAssertionSigner(pemBytes []byte, alg, kid string) (*ClientAssertionSigner, error) {
if !isSupportedClientAssertionAlg(alg) {
return nil, fmt.Errorf("unsupported client assertion alg %q", alg)
}
if kid == "" {
return nil, fmt.Errorf("kid must not be empty")
}
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, fmt.Errorf("no PEM block found in private key material")
}
var key crypto.PrivateKey
var parseErr error
switch block.Type {
case "PRIVATE KEY":
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
key, parseErr = x509.ParsePKCS1PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
key, parseErr = x509.ParseECPrivateKey(block.Bytes)
default:
// Best-effort fallback for unknown block types.
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
}
if parseErr != nil {
return nil, fmt.Errorf("failed to parse private key (block type %q): %w", block.Type, parseErr)
}
if err := validateAlgKeyMatch(alg, key); err != nil {
return nil, err
}
return &ClientAssertionSigner{key: key, alg: alg, kid: kid}, nil
}
// validateAlgKeyMatch returns an error when alg implies a key type that does
// not match the actual key.
func validateAlgKeyMatch(alg string, key crypto.PrivateKey) error {
switch alg[0] {
case 'R', 'P': // RS* or PS*
if _, ok := key.(*rsa.PrivateKey); !ok {
return fmt.Errorf("alg %q requires an RSA key, got %T", alg, key)
}
case 'E': // ES*
if _, ok := key.(*ecdsa.PrivateKey); !ok {
return fmt.Errorf("alg %q requires an EC key, got %T", alg, key)
}
}
return nil
}
// Sign constructs and returns a signed client_assertion JWT.
// audience is typically the token endpoint URL (RFC 7523 §3).
// clientID is used as both iss and sub per RFC 7523 §2.2.
func (s *ClientAssertionSigner) Sign(audience, clientID string) (string, error) {
rander := s.rand
if rander == nil {
rander = rand.Reader
}
nowFn := s.now
if nowFn == nil {
nowFn = time.Now
}
now := nowFn()
// 16 random bytes as lowercase hex for jti uniqueness.
jtiBytes := make([]byte, 16)
if _, err := io.ReadFull(rander, jtiBytes); err != nil {
return "", fmt.Errorf("failed to generate jti: %w", err)
}
jti := hex.EncodeToString(jtiBytes)
header := map[string]string{
"alg": s.alg,
"typ": "JWT",
"kid": s.kid,
}
hdrJSON, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
}
claims := map[string]any{
"iss": clientID,
"sub": clientID,
"aud": audience,
"jti": jti,
"iat": now.Unix(),
"exp": now.Add(60 * time.Second).Unix(),
}
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("failed to marshal JWT claims: %w", err)
}
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signingInput := hdrB64 + "." + claimsB64
sig, err := s.sign(rander, []byte(signingInput))
if err != nil {
return "", err
}
return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
}
// sign computes raw signature bytes for signingInput per s.alg.
// validateAlgKeyMatch in NewClientAssertionSigner guarantees the key type
// matches s.alg, but the comma-ok asserts here keep errcheck happy and
// surface internal misuse loudly instead of via panic.
func (s *ClientAssertionSigner) sign(rander io.Reader, input []byte) ([]byte, error) {
switch s.alg {
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
rsaKey, ok := s.key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("internal: alg %q requires *rsa.PrivateKey, got %T", s.alg, s.key)
}
hash := rsaHashForAlg(s.alg)
digest := hashSum(hash, input)
if s.alg[0] == 'R' {
return signRSAPKCS1v15(rander, rsaKey, hash, digest)
}
return signRSAPSS(rander, rsaKey, hash, digest)
case "ES256", "ES384", "ES512":
ecKey, ok := s.key.(*ecdsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("internal: alg %q requires *ecdsa.PrivateKey, got %T", s.alg, s.key)
}
hash := ecHashForAlg(s.alg)
digest := hashSum(hash, input)
return signECDSA(rander, ecKey, digest)
}
return nil, fmt.Errorf("unhandled alg %q", s.alg)
}
func rsaHashForAlg(alg string) crypto.Hash {
switch alg {
case "RS256", "PS256":
return crypto.SHA256
case "RS384", "PS384":
return crypto.SHA384
case "RS512", "PS512":
return crypto.SHA512
}
return 0
}
func ecHashForAlg(alg string) crypto.Hash {
switch alg {
case "ES256":
return crypto.SHA256
case "ES384":
return crypto.SHA384
case "ES512":
return crypto.SHA512
}
return 0
}
func hashSum(h crypto.Hash, input []byte) []byte {
switch h {
case crypto.SHA256:
sum := sha256.Sum256(input)
return sum[:]
case crypto.SHA384:
sum := sha512.Sum384(input)
return sum[:]
case crypto.SHA512:
sum := sha512.Sum512(input)
return sum[:]
}
return nil
}
func signRSAPKCS1v15(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
sig, err := rsa.SignPKCS1v15(rander, key, hash, digest)
if err != nil {
return nil, fmt.Errorf("RSA PKCS1v15 signing failed: %w", err)
}
return sig, nil
}
func signRSAPSS(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hash}
sig, err := rsa.SignPSS(rander, key, hash, digest, opts)
if err != nil {
return nil, fmt.Errorf("RSA PSS signing failed: %w", err)
}
return sig, nil
}
// signECDSA produces the JWS raw r||s signature (RFC 7515 App. A.3).
// Each scalar is zero-padded to (curve.BitSize+7)/8 bytes.
func signECDSA(rander io.Reader, key *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
r, ss, err := ecdsa.Sign(rander, key, digest)
if err != nil {
return nil, fmt.Errorf("ECDSA signing failed: %w", err)
}
byteLen := (key.Curve.Params().BitSize + 7) / 8
sig := make([]byte, 2*byteLen)
padBigInt(sig[0:byteLen], r)
padBigInt(sig[byteLen:], ss)
return sig, nil
}
// padBigInt writes n as a fixed-width big-endian integer into buf.
func padBigInt(buf []byte, n *big.Int) {
b := n.Bytes()
copy(buf[len(buf)-len(b):], b)
}
// buildClientAssertionSignerFromConfig loads key material and constructs a
// ClientAssertionSigner. Called from NewWithContext when
// ClientAuthMethod == "private_key_jwt".
func buildClientAssertionSignerFromConfig(config *Config) (*ClientAssertionSigner, error) {
var pemBytes []byte
if config.ClientAssertionPrivateKey != "" {
pemBytes = []byte(config.ClientAssertionPrivateKey)
} else {
data, err := os.ReadFile(config.ClientAssertionKeyPath)
if err != nil {
return nil, fmt.Errorf("read clientAssertionKeyPath %q: %w", config.ClientAssertionKeyPath, err)
}
pemBytes = data
}
alg := config.ClientAssertionAlg
if alg == "" {
alg = "RS256"
}
return NewClientAssertionSigner(pemBytes, alg, config.ClientAssertionKeyID)
}
+28
View File
@@ -0,0 +1,28 @@
# syntax=docker/dockerfile:1.7
#
# This Dockerfile is consumed by GoReleaser. The binary is built outside
# the Docker context (by goreleaser's Go cross-compile) and placed in the
# build context as ./oidcgate before `docker buildx build` runs.
#
# To build locally without goreleaser:
# go build -o oidcgate ./cmd/oidcgate
# docker build -f cmd/oidcgate/Dockerfile -t oidcgate:dev .
FROM gcr.io/distroless/static-debian12:nonroot
ARG TARGETOS
ARG TARGETARCH
LABEL org.opencontainers.image.title="oidcgate"
LABEL org.opencontainers.image.description="Standalone OIDC forward-auth daemon for nginx/Caddy/Traefik/HAProxy/Envoy"
LABEL org.opencontainers.image.source="https://github.com/lukaszraczylo/traefikoidc"
LABEL org.opencontainers.image.documentation="https://github.com/lukaszraczylo/traefikoidc/blob/main/docs/OIDCGATE.md"
LABEL org.opencontainers.image.licenses="MIT"
COPY oidcgate /usr/local/bin/oidcgate
EXPOSE 8080
USER nonroot:nonroot
ENTRYPOINT ["/usr/local/bin/oidcgate"]
CMD ["--config", "/etc/oidcgate/config.yaml"]
+222
View File
@@ -0,0 +1,222 @@
package main
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"unicode"
"github.com/lukaszraczylo/traefikoidc"
"gopkg.in/yaml.v3"
)
// Config is the top-level oidcgate configuration. The OIDC subtree maps 1:1
// onto traefikoidc.Config; the few extra fields configure the daemon itself.
type Config struct {
Listen string `json:"listen"`
AuthPath string `json:"authPath"`
StartPath string `json:"startPath"`
OIDC traefikoidc.Config `json:"-"`
}
// envScalarFields lists Config field names (within OIDC and top-level)
// that may be overridden via OIDCGATE_<UPPER_SNAKE_CASE> environment
// variables. Only scalar strings/ints/bools are supported; nested structs
// (Redis, SecurityHeaders, DynamicClientRegistration) stay YAML-only.
var envScalarFields = []string{
"Listen", "AuthPath", "StartPath",
"ProviderURL", "ClientID", "ClientSecret", "Audience",
"CallbackURL", "LogoutURL", "PostLogoutRedirectURI",
"SessionEncryptionKey", "CookiePrefix", "CookieDomain",
"LogLevel", "RevocationURL", "OIDCEndSessionURL",
"UserIdentifierClaim", "GroupClaimName", "RoleClaimName",
"ClientAuthMethod", "ClientAssertionPrivateKey",
"ClientAssertionKeyPath", "ClientAssertionKeyID", "ClientAssertionAlg",
"CACertPath", "CACertPEM",
}
// Load reads YAML from path, applies env-var overrides, fills defaults,
// and forces TrustForwardedURI=true so the library honors X-Forwarded-Uri.
func Load(path string) (*Config, error) {
// Clean the operator-supplied path to satisfy gosec G304 (file inclusion
// via variable). filepath.Clean strips traversal sequences and normalises
// the path; this is canonical mitigation for config files supplied via a
// CLI flag — the operator runs the daemon, so the input is trusted, but
// gosec's static analysis still flags variable paths without the cleanup.
clean := filepath.Clean(path)
data, err := os.ReadFile(clean) // #nosec G304 -- operator-supplied config path, cleaned above
if err != nil {
return nil, fmt.Errorf("read config: %w", err)
}
// Pass 1: YAML → generic map.
var raw map[string]any
if err := yaml.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("yaml parse: %w", err)
}
// Split the top-level oidcgate-specific keys away from the OIDC subtree.
listen, _ := raw["listen"].(string)
authPath, _ := raw["authPath"].(string)
startPath, _ := raw["startPath"].(string)
delete(raw, "listen")
delete(raw, "authPath")
delete(raw, "startPath")
// Pass 2: remaining map → JSON → traefikoidc.Config (uses existing json tags).
jsonBytes, err := json.Marshal(raw)
if err != nil {
return nil, fmt.Errorf("yaml→json: %w", err)
}
var oidcCfg traefikoidc.Config
if err := json.Unmarshal(jsonBytes, &oidcCfg); err != nil {
return nil, fmt.Errorf("oidc config parse: %w", err)
}
cfg := &Config{
Listen: listen,
AuthPath: authPath,
StartPath: startPath,
OIDC: oidcCfg,
}
applyEnvOverrides(cfg)
applyDefaults(cfg)
if cfg.Listen == "" {
return nil, fmt.Errorf("config: missing required 'listen' (or OIDCGATE_LISTEN env var)")
}
if !strings.HasPrefix(cfg.OIDC.CallbackURL, "/") {
return nil, fmt.Errorf("config: callbackURL must be a path starting with '/', got %q", cfg.OIDC.CallbackURL)
}
if !strings.HasPrefix(cfg.OIDC.LogoutURL, "/") {
return nil, fmt.Errorf("config: logoutURL must be a path starting with '/', got %q", cfg.OIDC.LogoutURL)
}
reserved := []string{
sentinelPath,
cfg.AuthPath,
cfg.StartPath,
cfg.OIDC.CallbackURL,
cfg.OIDC.LogoutURL,
}
for _, ex := range cfg.OIDC.ExcludedURLs {
for _, r := range reserved {
if r != "" && strings.HasPrefix(r, ex) {
return nil, fmt.Errorf("config: excludedURL %q would bypass reserved oidcgate path %q", ex, r)
}
}
}
// Force standalone semantics: trust X-Forwarded-Uri.
cfg.OIDC.TrustForwardedURI = true
return cfg, nil
}
// applyEnvOverrides walks the allow-listed scalar fields and replaces any
// non-empty OIDCGATE_<UPPER_SNAKE_CASE> env var. Field name "ClientID"
// becomes "OIDCGATE_CLIENT_ID"; "SessionEncryptionKey" becomes
// "OIDCGATE_SESSION_ENCRYPTION_KEY".
func applyEnvOverrides(cfg *Config) {
for _, field := range envScalarFields {
env := os.Getenv("OIDCGATE_" + camelToSnakeUpper(field))
if env == "" {
continue
}
setScalarField(cfg, field, env)
}
}
func setScalarField(cfg *Config, field, value string) {
switch field {
case "Listen":
cfg.Listen = value
case "AuthPath":
cfg.AuthPath = value
case "StartPath":
cfg.StartPath = value
case "ProviderURL":
cfg.OIDC.ProviderURL = value
case "ClientID":
cfg.OIDC.ClientID = value
case "ClientSecret":
cfg.OIDC.ClientSecret = value
case "Audience":
cfg.OIDC.Audience = value
case "CallbackURL":
cfg.OIDC.CallbackURL = value
case "LogoutURL":
cfg.OIDC.LogoutURL = value
case "PostLogoutRedirectURI":
cfg.OIDC.PostLogoutRedirectURI = value
case "SessionEncryptionKey":
cfg.OIDC.SessionEncryptionKey = value
case "CookiePrefix":
cfg.OIDC.CookiePrefix = value
case "CookieDomain":
cfg.OIDC.CookieDomain = value
case "LogLevel":
cfg.OIDC.LogLevel = value
case "RevocationURL":
cfg.OIDC.RevocationURL = value
case "OIDCEndSessionURL":
cfg.OIDC.OIDCEndSessionURL = value
case "UserIdentifierClaim":
cfg.OIDC.UserIdentifierClaim = value
case "GroupClaimName":
cfg.OIDC.GroupClaimName = value
case "RoleClaimName":
cfg.OIDC.RoleClaimName = value
case "ClientAuthMethod":
cfg.OIDC.ClientAuthMethod = value
case "ClientAssertionPrivateKey":
cfg.OIDC.ClientAssertionPrivateKey = value
case "ClientAssertionKeyPath":
cfg.OIDC.ClientAssertionKeyPath = value
case "ClientAssertionKeyID":
cfg.OIDC.ClientAssertionKeyID = value
case "ClientAssertionAlg":
cfg.OIDC.ClientAssertionAlg = value
case "CACertPath":
cfg.OIDC.CACertPath = value
case "CACertPEM":
cfg.OIDC.CACertPEM = value
}
}
func applyDefaults(cfg *Config) {
if cfg.AuthPath == "" {
cfg.AuthPath = "/oauth2/auth"
}
if cfg.StartPath == "" {
cfg.StartPath = "/oauth2/start"
}
}
// camelToSnakeUpper turns "ClientSecret" into "CLIENT_SECRET",
// "SessionEncryptionKey" into "SESSION_ENCRYPTION_KEY", etc.
// Multi-letter acronyms keep their grouping: "OIDCEndSessionURL" →
// "OIDC_END_SESSION_URL", "CACertPEM" → "CA_CERT_PEM".
func camelToSnakeUpper(s string) string {
runes := []rune(s)
var b strings.Builder
for i, r := range runes {
if i > 0 && isUpper(r) {
prev := runes[i-1]
next := rune(0)
if i+1 < len(runes) {
next = runes[i+1]
}
if !isUpper(prev) || (next != 0 && !isUpper(next)) {
b.WriteByte('_')
}
}
b.WriteRune(unicode.ToUpper(r))
}
return b.String()
}
func isUpper(r rune) bool { return r >= 'A' && r <= 'Z' }
+303
View File
@@ -0,0 +1,303 @@
package main
import (
"os"
"path/filepath"
"testing"
)
// minimalYAML is a base config accepted by Load with no surprises.
const minimalYAML = `
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
`
func writeConfig(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
t.Fatal(err)
}
return path
}
func TestLoad_YAMLRoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, []byte(`
listen: ":9090"
authPath: "/auth"
startPath: "/start"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
`), 0o600); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatal(err)
}
if cfg.Listen != ":9090" {
t.Errorf("listen: want :9090, got %q", cfg.Listen)
}
if cfg.AuthPath != "/auth" {
t.Errorf("authPath: want /auth, got %q", cfg.AuthPath)
}
if cfg.StartPath != "/start" {
t.Errorf("startPath: want /start, got %q", cfg.StartPath)
}
if cfg.OIDC.ClientID != "abc" {
t.Errorf("clientID: want abc, got %q", cfg.OIDC.ClientID)
}
if cfg.OIDC.ClientSecret != "secret" {
t.Errorf("clientSecret: want secret, got %q", cfg.OIDC.ClientSecret)
}
if !cfg.OIDC.TrustForwardedURI {
t.Errorf("TrustForwardedURI should be forced true by Load")
}
}
func TestLoad_EnvOverride(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, []byte(`
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "from-file"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
`), 0o600); err != nil {
t.Fatal(err)
}
t.Setenv("OIDCGATE_CLIENT_SECRET", "from-env")
t.Setenv("OIDCGATE_LISTEN", ":9999")
cfg, err := Load(path)
if err != nil {
t.Fatal(err)
}
if cfg.OIDC.ClientSecret != "from-env" {
t.Errorf("env override (clientSecret): want from-env, got %q", cfg.OIDC.ClientSecret)
}
if cfg.Listen != ":9999" {
t.Errorf("env override (listen): want :9999, got %q", cfg.Listen)
}
}
func TestLoad_Defaults(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, []byte(`
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
`), 0o600); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatal(err)
}
if cfg.AuthPath != "/oauth2/auth" {
t.Errorf("AuthPath default: want /oauth2/auth, got %q", cfg.AuthPath)
}
if cfg.StartPath != "/oauth2/start" {
t.Errorf("StartPath default: want /oauth2/start, got %q", cfg.StartPath)
}
}
func TestLoad_MissingFile(t *testing.T) {
if _, err := Load("/nonexistent/config.yaml"); err == nil {
t.Fatal("expected error for missing file")
}
}
func TestLoad_NestedStructRoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "config.yaml")
if err := os.WriteFile(path, []byte(`
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
redis:
address: "redis:6379"
password: "redispw"
`), 0o600); err != nil {
t.Fatal(err)
}
cfg, err := Load(path)
if err != nil {
t.Fatal(err)
}
if cfg.OIDC.Redis == nil {
t.Fatal("redis block should populate cfg.OIDC.Redis")
}
if cfg.OIDC.Redis.Address != "redis:6379" {
t.Errorf("redis address: want redis:6379, got %q", cfg.OIDC.Redis.Address)
}
}
// Fix 5: callbackURL / logoutURL must start with "/"
func TestLoad_RejectsAbsoluteCallbackURL(t *testing.T) {
path := writeConfig(t, `
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "https://app.example.com/oauth2/callback"
logoutURL: "/oauth2/logout"
`)
if _, err := Load(path); err == nil {
t.Fatal("callbackURL with absolute URL must be rejected")
}
}
func TestLoad_RejectsAbsoluteLogoutURL(t *testing.T) {
path := writeConfig(t, `
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "https://app.example.com/oauth2/logout"
`)
if _, err := Load(path); err == nil {
t.Fatal("logoutURL with absolute URL must be rejected")
}
}
// Fix 2: excludedURLs must not prefix reserved paths
func TestLoad_RejectsExcludedURLPrefixingReservedPath(t *testing.T) {
path := writeConfig(t, `
listen: ":8080"
providerURL: "https://idp.example"
clientID: "abc"
clientSecret: "secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
excludedURLs: ["/"]
`)
if _, err := Load(path); err == nil {
t.Fatal("excludedURLs: ['/'] must be rejected (bypasses all reserved paths)")
}
}
func TestLoad_AllowsNonOverlappingExcludedURL(t *testing.T) {
path := writeConfig(t, minimalYAML+`excludedURLs: ["/public"]
`)
if _, err := Load(path); err != nil {
t.Fatalf("non-overlapping excludedURL must be accepted: %v", err)
}
}
// Fix 3: env override coverage — every envScalarFields entry must have a
// matching case in setScalarField. isAllZeroForField detects drift.
func TestEnvOverrideCoverage(t *testing.T) {
for _, field := range envScalarFields {
field := field
t.Run(field, func(t *testing.T) {
probe := "/safe/probe-" + field
if field == "SessionEncryptionKey" {
probe = "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"
}
if field == "LogLevel" {
probe = "debug"
}
if field == "ClientAuthMethod" {
probe = "client_secret_post"
}
if field == "ClientAssertionAlg" {
probe = "RS256"
}
var fresh Config
setScalarField(&fresh, field, probe)
if isAllZeroForField(&fresh, field, probe) {
t.Fatalf("envScalarFields includes %q but setScalarField has no matching case (drift)", field)
}
})
}
}
// isAllZeroForField returns true when setScalarField did NOT set the expected
// field — i.e., the switch is missing a case for `field`.
func isAllZeroForField(cfg *Config, field, probe string) bool {
switch field {
case "Listen":
return cfg.Listen != probe
case "AuthPath":
return cfg.AuthPath != probe
case "StartPath":
return cfg.StartPath != probe
case "ProviderURL":
return cfg.OIDC.ProviderURL != probe
case "ClientID":
return cfg.OIDC.ClientID != probe
case "ClientSecret":
return cfg.OIDC.ClientSecret != probe
case "Audience":
return cfg.OIDC.Audience != probe
case "CallbackURL":
return cfg.OIDC.CallbackURL != probe
case "LogoutURL":
return cfg.OIDC.LogoutURL != probe
case "PostLogoutRedirectURI":
return cfg.OIDC.PostLogoutRedirectURI != probe
case "SessionEncryptionKey":
return cfg.OIDC.SessionEncryptionKey != probe
case "CookiePrefix":
return cfg.OIDC.CookiePrefix != probe
case "CookieDomain":
return cfg.OIDC.CookieDomain != probe
case "LogLevel":
return cfg.OIDC.LogLevel != probe
case "RevocationURL":
return cfg.OIDC.RevocationURL != probe
case "OIDCEndSessionURL":
return cfg.OIDC.OIDCEndSessionURL != probe
case "UserIdentifierClaim":
return cfg.OIDC.UserIdentifierClaim != probe
case "GroupClaimName":
return cfg.OIDC.GroupClaimName != probe
case "RoleClaimName":
return cfg.OIDC.RoleClaimName != probe
case "ClientAuthMethod":
return cfg.OIDC.ClientAuthMethod != probe
case "ClientAssertionPrivateKey":
return cfg.OIDC.ClientAssertionPrivateKey != probe
case "ClientAssertionKeyPath":
return cfg.OIDC.ClientAssertionKeyPath != probe
case "ClientAssertionKeyID":
return cfg.OIDC.ClientAssertionKeyID != probe
case "ClientAssertionAlg":
return cfg.OIDC.ClientAssertionAlg != probe
case "CACertPath":
return cfg.OIDC.CACertPath != probe
case "CACertPEM":
return cfg.OIDC.CACertPEM != probe
}
return true // unknown field → drift
}
+69
View File
@@ -0,0 +1,69 @@
package main
import "net/http"
// sentinelPath is the synthetic request path used when delegating /oauth2/auth
// and /oauth2/start into the traefikoidc middleware. It must NOT collide with
// callbackURL, logoutURL, /health*, or any plausible excludedURLs entry —
// the underscores and double-prefixing make accidental matches near-impossible.
const sentinelPath = "/__oidcgate_protected__"
// newAuthHandler builds the /oauth2/auth (silent probe) handler.
// Rewrites the request path to sentinelPath, wraps the ResponseWriter to
// convert the middleware's 302→IdP into 401, and delegates.
func newAuthHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
ic := newAuthInterceptor(rw)
defer ic.Finalize()
r2 := cloneAndRewrite(req, sentinelPath)
next.ServeHTTP(ic, r2)
})
}
// newStartHandler builds the /oauth2/start (visible sign-in) handler.
// Rewrites the path to sentinelPath, forwards any ?rd= query as
// X-Forwarded-Uri so the middleware (with TrustForwardedURI=true) captures
// the right post-login redirect target, then delegates. The middleware's
// natural 302→IdP flows through unchanged.
func newStartHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r2 := cloneAndRewrite(req, sentinelPath)
// Precedence: explicit ?rd= wins over an ambient upstream
// X-Forwarded-Uri so /oauth2/start?rd=/dashboard does not get
// silently overridden by the proxy's current-URL forwarding.
if rd := req.URL.Query().Get("rd"); rd != "" {
r2.Header.Set("X-Forwarded-Uri", rd)
}
next.ServeHTTP(rw, r2)
})
}
// newCallbackHandler builds the IdP callback endpoint.
// Rewrites the request path to the configured callbackURL so the middleware's
// path-match at the top of ServeHTTP triggers the callback flow.
func newCallbackHandler(next http.Handler, callbackURL string) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r2 := cloneAndRewrite(req, callbackURL)
next.ServeHTTP(rw, r2)
})
}
// newLogoutHandler builds the logout endpoint.
// Rewrites the request path to the configured logoutURL so the middleware's
// path-match at the top of ServeHTTP triggers the logout flow.
func newLogoutHandler(next http.Handler, logoutURL string) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
r2 := cloneAndRewrite(req, logoutURL)
next.ServeHTTP(rw, r2)
})
}
// cloneAndRewrite returns a clone of req with URL.Path set to newPath.
// req.Clone deep-copies URL via net/http's cloneURL, so mutating
// r2.URL.Path does not affect the original req. RawQuery, Host,
// Fragment, RawPath are preserved unchanged.
func cloneAndRewrite(req *http.Request, newPath string) *http.Request {
r2 := req.Clone(req.Context())
r2.URL.Path = newPath
return r2
}
+167
View File
@@ -0,0 +1,167 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
// stubMiddleware lets us test endpoint wiring without spinning up a full
// traefikoidc instance. Each test injects the behavior it wants.
type stubMiddleware struct {
calls []stubCall
fn func(rw http.ResponseWriter, req *http.Request)
}
type stubCall struct {
path string
header http.Header
}
func (s *stubMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.calls = append(s.calls, stubCall{path: req.URL.Path, header: req.Header.Clone()})
if s.fn != nil {
s.fn(rw, req)
}
}
func TestAuth_RewritesToSentinel_AndConverts302To401(t *testing.T) {
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", "https://idp.example/authorize?state=abc")
rw.Header().Add("Set-Cookie", "_oidc_state=abc; Path=/")
rw.WriteHeader(http.StatusFound)
},
}
h := newAuthHandler(stub)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil)
req.Header.Set("X-Forwarded-Uri", "/protected/page")
h.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: want 401, got %d", rec.Code)
}
if len(stub.calls) != 1 || stub.calls[0].path != sentinelPath {
t.Fatalf("middleware path: want %q, got %v", sentinelPath, stub.calls)
}
if rec.Header().Get("X-Auth-Redirect") == "" {
t.Error("X-Auth-Redirect should carry Location")
}
if got := stub.calls[0].header.Get("X-Forwarded-Uri"); got != "/protected/page" {
t.Errorf("X-Forwarded-Uri must pass through to middleware: want /protected/page, got %q", got)
}
}
func TestAuth_AuthenticatedReturnsHeadersAnd200(t *testing.T) {
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) {
// Middleware would stamp X-Forwarded-User on req then call next.
req.Header.Set("X-Forwarded-User", "alice")
newSuccessHandler().ServeHTTP(rw, req)
},
}
h := newAuthHandler(stub)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil)
h.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status: want 200, got %d", rec.Code)
}
if got := rec.Header().Get("X-Forwarded-User"); got != "alice" {
t.Errorf("X-Forwarded-User mirrored: want alice, got %q", got)
}
}
func TestStart_DelegatesWithSentinel_NoInterception(t *testing.T) {
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", "https://idp.example/authorize")
rw.WriteHeader(http.StatusFound)
},
}
h := newStartHandler(stub)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/back", nil)
h.ServeHTTP(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("start: 302 must flow through, got %d", rec.Code)
}
if stub.calls[0].path != sentinelPath {
t.Fatalf("start path rewrite: want %q, got %q", sentinelPath, stub.calls[0].path)
}
}
func TestStart_ForwardsRdAsXForwardedURI(t *testing.T) {
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusFound) },
}
h := newStartHandler(stub)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/back/here", nil)
h.ServeHTTP(rec, req)
if got := stub.calls[0].header.Get("X-Forwarded-Uri"); got != "/back/here" {
t.Fatalf("?rd should become X-Forwarded-Uri: want /back/here, got %q", got)
}
}
func TestStart_RdQueryWinsOverUpstreamHeader(t *testing.T) {
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusFound) },
}
h := newStartHandler(stub)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/explicit", nil)
req.Header.Set("X-Forwarded-Uri", "/ambient")
h.ServeHTTP(rec, req)
if got := stub.calls[0].header.Get("X-Forwarded-Uri"); got != "/explicit" {
t.Fatalf("?rd= must win over upstream X-Forwarded-Uri: want /explicit, got %q", got)
}
}
func TestCallback_RewritesToConfiguredCallbackURL(t *testing.T) {
var seenPath, seenQuery string
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) {
seenPath = req.URL.Path
seenQuery = req.URL.RawQuery
rw.WriteHeader(http.StatusOK)
},
}
h := newCallbackHandler(stub, "/oauth2/callback")
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/callback?code=abc&state=xyz", nil)
h.ServeHTTP(rec, req)
if seenPath != "/oauth2/callback" {
t.Fatalf("callback path: want /oauth2/callback, got %q", seenPath)
}
if seenQuery != "code=abc&state=xyz" {
t.Fatalf("callback query must survive rewrite: want code=abc&state=xyz, got %q", seenQuery)
}
}
func TestLogout_RewritesToConfiguredLogoutURL(t *testing.T) {
var seenPath string
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, req *http.Request) {
seenPath = req.URL.Path
rw.WriteHeader(http.StatusOK)
},
}
h := newLogoutHandler(stub, "/oauth2/logout")
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/oauth2/logout", nil)
h.ServeHTTP(rec, req)
if seenPath != "/oauth2/logout" {
t.Fatalf("logout path: want /oauth2/logout, got %q", seenPath)
}
}
+24
View File
@@ -0,0 +1,24 @@
package main
import "net/http"
// readyReporter is satisfied by *traefikoidc.TraefikOidc via its Ready() method.
type readyReporter interface {
Ready() bool
}
func newHealthzHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusOK)
})
}
func newReadyzHandler(r readyReporter) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
if r.Ready() {
rw.WriteHeader(http.StatusOK)
return
}
rw.WriteHeader(http.StatusServiceUnavailable)
})
}
+38
View File
@@ -0,0 +1,38 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
type readyStub struct{ ready bool }
func (r *readyStub) Ready() bool { return r.ready }
func TestHealthz_Always200(t *testing.T) {
h := newHealthzHandler()
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/healthz", nil))
if rec.Code != http.StatusOK {
t.Fatalf("healthz: want 200, got %d", rec.Code)
}
}
func TestReadyz_503BeforeDiscovery(t *testing.T) {
h := newReadyzHandler(&readyStub{ready: false})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/readyz", nil))
if rec.Code != http.StatusServiceUnavailable {
t.Fatalf("readyz pre-discovery: want 503, got %d", rec.Code)
}
}
func TestReadyz_200AfterDiscovery(t *testing.T) {
h := newReadyzHandler(&readyStub{ready: true})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/readyz", nil))
if rec.Code != http.StatusOK {
t.Fatalf("readyz post-discovery: want 200, got %d", rec.Code)
}
}
+15
View File
@@ -0,0 +1,15 @@
package main
import "github.com/lukaszraczylo/traefikoidc"
type traefikoidcConfigStub struct {
callbackURL string
logoutURL string
}
func (s traefikoidcConfigStub) AsOIDC() traefikoidc.Config {
return traefikoidc.Config{
CallbackURL: s.callbackURL,
LogoutURL: s.logoutURL,
}
}
+195
View File
@@ -0,0 +1,195 @@
package main
import (
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc"
)
// fakeProviderHost is a synthetic hostname used in place of the httptest.Server's
// 127.0.0.1 address. traefikoidc's URL validator blocks loopback IPs
// unconditionally; a non-loopback hostname passes the check. The custom HTTP
// client returned by mockHTTPClient rewires all dials for this host to the
// actual test-server port, so the mock IdP still receives every request.
const fakeProviderHost = "test-oidc-provider.local"
// mockHTTPClient returns an *http.Client whose dialer transparently redirects
// connections to fakeProviderHost to the real httptest.Server address.
func mockHTTPClient(realAddr string) *http.Client {
dialer := &net.Dialer{Timeout: 5 * time.Second}
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return dialer.DialContext(ctx, network, addr)
}
if host == fakeProviderHost {
addr = realAddr
}
return dialer.DialContext(ctx, network, addr)
},
}
return &http.Client{Transport: transport}
}
// newMockIdP returns an httptest.Server that serves the minimal OIDC
// discovery surface required by traefikoidc.NewWithContext to bootstrap
// — discovery doc + an empty JWKS. All URLs in the discovery doc use
// fakeProviderHost so they pass the middleware's URL security validator.
func newMockIdP(t *testing.T) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
fakeBase := "http://" + fakeProviderHost
mux.HandleFunc("/.well-known/openid-configuration", func(rw http.ResponseWriter, _ *http.Request) {
discovery := map[string]any{
"issuer": fakeBase,
"authorization_endpoint": fakeBase + "/authorize",
"token_endpoint": fakeBase + "/token",
"jwks_uri": fakeBase + "/jwks",
"response_types_supported": []string{"code"},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{"RS256"},
}
rw.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(rw).Encode(discovery)
})
mux.HandleFunc("/jwks", func(rw http.ResponseWriter, _ *http.Request) {
rw.Header().Set("Content-Type", "application/json")
_, _ = rw.Write([]byte(`{"keys":[]}`))
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv
}
// buildTestConfig produces a Config that points at the fake provider hostname
// (which the custom HTTP client redirects to the real mock server) and uses
// a known-good SessionEncryptionKey + safe path defaults.
func buildTestConfig(srv *httptest.Server) *Config {
// realAddr is HOST:PORT of the httptest server (e.g. "127.0.0.1:56789").
realAddr := srv.Listener.Addr().String()
cfg := &Config{
Listen: "127.0.0.1:0", // unused — we drive the mux directly via httptest
AuthPath: "/oauth2/auth",
StartPath: "/oauth2/start",
OIDC: traefikoidc.Config{
ProviderURL: "http://" + fakeProviderHost,
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
CallbackURL: "/oauth2/callback",
LogoutURL: "/oauth2/logout",
TrustForwardedURI: true,
EnablePKCE: true,
HTTPClient: mockHTTPClient(realAddr),
},
}
return cfg
}
// buildIntegrationStack builds the same wiring main.go builds: real
// middleware constructed against the mock IdP, success handler as next,
// mux on top.
func buildIntegrationStack(t *testing.T, idp *httptest.Server) (http.Handler, *traefikoidc.TraefikOidc) {
t.Helper()
cfg := buildTestConfig(idp)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
mw, err := traefikoidc.NewWithContext(ctx, &cfg.OIDC, newSuccessHandler(), "oidcgate-test")
if err != nil {
t.Fatalf("NewWithContext: %v", err)
}
mux := buildMux(cfg, mw, mw)
return mux, mw
}
func TestIntegration_UnauthenticatedAuthReturns401WithRedirect(t *testing.T) {
idp := newMockIdP(t)
mux, _ := buildIntegrationStack(t, idp)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: want 401, got %d (body=%q)", rec.Code, rec.Body.String())
}
loc := rec.Header().Get("X-Auth-Redirect")
if loc == "" {
t.Fatal("X-Auth-Redirect should carry the IdP authorize URL")
}
if !strings.HasPrefix(loc, "http://"+fakeProviderHost+"/authorize") {
t.Errorf("X-Auth-Redirect should point at the mock IdP authorize endpoint, got %q", loc)
}
if cookies := rec.Header().Values("Set-Cookie"); len(cookies) == 0 {
t.Error("expected at least one Set-Cookie (state/PKCE/nonce) on 401")
}
}
func TestIntegration_StartRedirectsToIdPWithStateAndPKCE(t *testing.T) {
idp := newMockIdP(t)
mux, _ := buildIntegrationStack(t, idp)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/dashboard", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusFound {
t.Fatalf("status: want 302, got %d", rec.Code)
}
loc := rec.Header().Get("Location")
if !strings.HasPrefix(loc, "http://"+fakeProviderHost+"/authorize") {
t.Fatalf("Location: want prefix http://%s/authorize, got %q", fakeProviderHost, loc)
}
if !strings.Contains(loc, "state=") {
t.Errorf("Location should include state= param, got %q", loc)
}
if !strings.Contains(loc, "code_challenge=") {
t.Errorf("Location should include code_challenge= param (PKCE), got %q", loc)
}
}
func TestIntegration_HealthzAlways200(t *testing.T) {
idp := newMockIdP(t)
mux, _ := buildIntegrationStack(t, idp)
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/healthz", nil))
if rec.Code != http.StatusOK {
t.Fatalf("healthz: want 200, got %d", rec.Code)
}
}
func TestIntegration_ReadyzBecomes200AfterDiscovery(t *testing.T) {
idp := newMockIdP(t)
mux, mw := buildIntegrationStack(t, idp)
// Hit /oauth2/auth once to trigger metadata discovery (the middleware
// performs discovery lazily on first request).
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil))
// Poll Ready() until true or timeout.
deadline := time.Now().Add(3 * time.Second)
for time.Now().Before(deadline) {
if mw.Ready() {
break
}
time.Sleep(50 * time.Millisecond)
}
if !mw.Ready() {
t.Fatal("middleware should be Ready() within 3s after first request triggered discovery")
}
rec = httptest.NewRecorder()
mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/readyz", nil))
if rec.Code != http.StatusOK {
t.Fatalf("readyz post-discovery: want 200, got %d", rec.Code)
}
}
+83
View File
@@ -0,0 +1,83 @@
package main
import "net/http"
// authInterceptor wraps a ResponseWriter for the /oauth2/auth endpoint.
// The traefikoidc middleware emits an HTTP 302 to the IdP authorize URL
// when a request is unauthenticated, but nginx auth_request and similar
// silent-probe contracts cannot follow redirects. authInterceptor buffers
// the header/body and, at Finalize() time:
//
// - if status was a redirect class (302, 303, 307, 308), rewrites it
// to 401, moves the original Location header to X-Auth-Redirect
// (advisory), strips Location, preserves Set-Cookie headers (state,
// PKCE, nonce — the browser will carry them into the next request),
// and writes an empty body.
// - otherwise: passes through verbatim.
type authInterceptor struct {
inner http.ResponseWriter
headers http.Header
status int
body []byte
wroteHeader bool
finalized bool
}
func newAuthInterceptor(inner http.ResponseWriter) *authInterceptor {
return &authInterceptor{
inner: inner,
headers: http.Header{},
status: http.StatusOK,
}
}
func (w *authInterceptor) Header() http.Header { return w.headers }
func (w *authInterceptor) WriteHeader(status int) {
if w.wroteHeader {
return
}
w.status = status
w.wroteHeader = true
}
func (w *authInterceptor) Write(b []byte) (int, error) { //nolint:unparam // signature mandated by http.ResponseWriter
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
w.body = append(w.body, b...)
return len(b), nil
}
// Finalize flushes the buffered response, applying the 302/303 → 401 rewrite.
// Must be called exactly once after the wrapped handler returns.
func (w *authInterceptor) Finalize() {
if w.finalized {
return
}
w.finalized = true
switch w.status {
case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
// Move Location → X-Auth-Redirect, strip Location, force 401, drop body.
if loc := w.headers.Get("Location"); loc != "" {
w.headers.Set("X-Auth-Redirect", loc)
w.headers.Del("Location")
}
copyHeaders(w.inner.Header(), w.headers)
w.inner.WriteHeader(http.StatusUnauthorized)
return
}
copyHeaders(w.inner.Header(), w.headers)
w.inner.WriteHeader(w.status)
if len(w.body) > 0 {
_, _ = w.inner.Write(w.body)
}
}
func copyHeaders(dst, src http.Header) {
for k, vs := range src {
for _, v := range vs {
dst.Add(k, v)
}
}
}
+108
View File
@@ -0,0 +1,108 @@
package main
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestInterceptor_302BecomesNot401(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("Location", "https://idp.example/authorize?state=abc")
w.Header().Add("Set-Cookie", "_oidc_state=abc; Path=/; HttpOnly")
w.Header().Add("Set-Cookie", "_oidc_pkce=xyz; Path=/; HttpOnly")
w.WriteHeader(http.StatusFound)
_, _ = w.Write([]byte("ignored body"))
w.Finalize()
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: want 401, got %d", rec.Code)
}
if got := rec.Header().Get("X-Auth-Redirect"); got != "https://idp.example/authorize?state=abc" {
t.Errorf("X-Auth-Redirect: want preserved Location, got %q", got)
}
if got := rec.Header().Get("Location"); got != "" {
t.Errorf("Location must be stripped on 401, got %q", got)
}
cookies := rec.Header().Values("Set-Cookie")
if len(cookies) != 2 {
t.Fatalf("Set-Cookie count: want 2, got %d (%v)", len(cookies), cookies)
}
if body := strings.TrimSpace(rec.Body.String()); body != "" {
t.Errorf("body must be empty on 401, got %q", body)
}
}
func TestInterceptor_NonRedirectPassthrough(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("X-Forwarded-User", "alice")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
w.Finalize()
if rec.Code != http.StatusOK {
t.Fatalf("status: want 200, got %d", rec.Code)
}
if got := rec.Header().Get("X-Forwarded-User"); got != "alice" {
t.Errorf("X-Forwarded-User: want preserved, got %q", got)
}
if !strings.Contains(rec.Body.String(), "ok") {
t.Errorf("body: want 'ok' preserved, got %q", rec.Body.String())
}
}
func TestInterceptor_303SeeOtherAlsoIntercepted(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("Location", "/elsewhere")
w.WriteHeader(http.StatusSeeOther)
w.Finalize()
if rec.Code != http.StatusUnauthorized {
t.Fatalf("303 should be intercepted to 401, got %d", rec.Code)
}
}
func TestInterceptor_307TemporaryRedirectIntercepted(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("Location", "/elsewhere")
w.WriteHeader(http.StatusTemporaryRedirect)
w.Finalize()
if rec.Code != http.StatusUnauthorized {
t.Fatalf("307 should be intercepted to 401, got %d", rec.Code)
}
}
func TestInterceptor_308PermanentRedirectIntercepted(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("Location", "/elsewhere")
w.WriteHeader(http.StatusPermanentRedirect)
w.Finalize()
if rec.Code != http.StatusUnauthorized {
t.Fatalf("308 should be intercepted to 401, got %d", rec.Code)
}
}
func TestInterceptor_DoubleFinalizeIsNoop(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("X-Forwarded-User", "alice")
w.WriteHeader(http.StatusOK)
w.Finalize()
// Second call must not panic, must not change anything observable.
w.Finalize()
if rec.Code != http.StatusOK {
t.Fatalf("double Finalize must not change status, got %d", rec.Code)
}
if got := rec.Header().Get("X-Forwarded-User"); got != "alice" {
t.Errorf("double Finalize must not duplicate headers, got %q", got)
}
}
+54
View File
@@ -0,0 +1,54 @@
package main
import (
"context"
"errors"
"flag"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/lukaszraczylo/traefikoidc"
)
func main() {
configPath := flag.String("config", "/etc/oidcgate/config.yaml", "Path to YAML config file")
flag.Parse()
cfg, err := Load(*configPath)
if err != nil {
log.Fatalf("oidcgate: load config: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
success := newSuccessHandler()
middleware, err := traefikoidc.NewWithContext(ctx, &cfg.OIDC, success, "oidcgate")
if err != nil {
cancel()
log.Fatalf("oidcgate: build middleware: %v", err)
}
mux := buildMux(cfg, middleware, middleware)
srv := buildServer(cfg, mux)
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
<-sigs
log.Println("oidcgate: shutdown signal received")
if err := shutdown(srv); err != nil {
log.Printf("oidcgate: shutdown error: %v", err)
}
}()
log.Printf("oidcgate: listening on %s", cfg.Listen)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
cancel()
log.Fatalf("oidcgate: serve: %v", err)
}
log.Println("oidcgate: stopped")
}
+43
View File
@@ -0,0 +1,43 @@
package main
import (
"context"
"net/http"
"time"
)
// buildMux wires all six routes onto a single ServeMux:
//
// /healthz, /readyz, AuthPath, StartPath, OIDC.CallbackURL, OIDC.LogoutURL.
//
// The same `middleware` instance is delegated to by all four OIDC routes;
// the synthetic success handler is wired into the middleware at construction
// time (in main.go) so it doesn't appear here.
func buildMux(cfg *Config, middleware http.Handler, ready readyReporter) *http.ServeMux {
mux := http.NewServeMux()
mux.Handle("/healthz", newHealthzHandler())
mux.Handle("/readyz", newReadyzHandler(ready))
mux.Handle(cfg.AuthPath, newAuthHandler(middleware))
mux.Handle(cfg.StartPath, newStartHandler(middleware))
mux.Handle(cfg.OIDC.CallbackURL, newCallbackHandler(middleware, cfg.OIDC.CallbackURL))
mux.Handle(cfg.OIDC.LogoutURL, newLogoutHandler(middleware, cfg.OIDC.LogoutURL))
return mux
}
// buildServer wraps the mux in an http.Server with sensible timeouts.
func buildServer(cfg *Config, mux http.Handler) *http.Server { //nolint:unused // consumed by main.go in Task 9
return &http.Server{
Addr: cfg.Listen,
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
}
}
// shutdown gracefully stops the server with a 15s deadline.
func shutdown(srv *http.Server) error { //nolint:unused // consumed by main.go in Task 9
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return srv.Shutdown(ctx)
}
+42
View File
@@ -0,0 +1,42 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestMux_RoutesAllEndpoints(t *testing.T) {
stub := &stubMiddleware{
fn: func(rw http.ResponseWriter, _ *http.Request) { rw.WriteHeader(http.StatusOK) },
}
mux := buildMux(&Config{
Listen: ":0",
AuthPath: "/oauth2/auth",
StartPath: "/oauth2/start",
OIDC: traefikoidcConfigStub{
callbackURL: "/oauth2/callback",
logoutURL: "/oauth2/logout",
}.AsOIDC(),
}, stub, &readyStub{ready: true})
cases := []struct {
path string
method string
want int
}{
{"/healthz", http.MethodGet, http.StatusOK},
{"/readyz", http.MethodGet, http.StatusOK},
{"/oauth2/auth", http.MethodGet, http.StatusOK},
{"/oauth2/start", http.MethodGet, http.StatusOK},
{"/oauth2/callback", http.MethodGet, http.StatusOK},
{"/oauth2/logout", http.MethodPost, http.StatusOK},
}
for _, c := range cases {
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, httptest.NewRequest(c.method, c.path, nil))
if rec.Code != c.want {
t.Errorf("%s %s: want %d, got %d", c.method, c.path, c.want, rec.Code)
}
}
}
+43
View File
@@ -0,0 +1,43 @@
package main
import (
"net/http"
"strings"
)
// mirrorAllowedHeaders is the set of NON-X-prefixed request headers that the
// success handler copies onto the response. The traefikoidc middleware sets
// "Authorization: Bearer ..." via the templated-header feature when operators
// configure it, and proxies need that to flow upstream.
var mirrorAllowedHeaders = map[string]struct{}{
"Authorization": {},
}
// successHandler is the http.Handler installed as the middleware's `next`.
// When the middleware reaches this handler the request is authenticated; we
// mirror the X-* (and a small allow-list of non-X-*) headers the middleware
// stamped onto req.Header back onto the response so upstream proxies can
// capture them via auth_request_set / authResponseHeaders / copy_headers,
// then write 200 with an empty body.
func newSuccessHandler() http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for name, values := range req.Header {
if !shouldMirror(name) {
continue
}
for _, v := range values {
rw.Header().Add(name, v)
}
}
rw.WriteHeader(http.StatusOK)
})
}
func shouldMirror(name string) bool {
if strings.HasPrefix(name, "X-") {
return true
}
canonical := http.CanonicalHeaderKey(name)
_, ok := mirrorAllowedHeaders[canonical]
return ok
}
+70
View File
@@ -0,0 +1,70 @@
package main
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestSuccessHandler_Writes200(t *testing.T) {
h := newSuccessHandler()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/x", nil)
h.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status: want 200, got %d", rec.Code)
}
}
func TestSuccessHandler_MirrorsForwardedHeaders(t *testing.T) {
h := newSuccessHandler()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/x", nil)
req.Header.Set("X-Forwarded-User", "alice@example.com")
req.Header.Set("X-Forwarded-Email", "alice@example.com")
req.Header.Set("X-Custom-Templated", "value")
req.Header.Set("Authorization", "Bearer token-from-template")
req.Header.Set("Cookie", "session=should-NOT-mirror")
h.ServeHTTP(rec, req)
if got := rec.Header().Get("X-Forwarded-User"); got != "alice@example.com" {
t.Errorf("X-Forwarded-User: want mirrored, got %q", got)
}
if got := rec.Header().Get("X-Forwarded-Email"); got != "alice@example.com" {
t.Errorf("X-Forwarded-Email: want mirrored, got %q", got)
}
if got := rec.Header().Get("X-Custom-Templated"); got != "value" {
t.Errorf("X-Custom-Templated: want mirrored (X- prefix), got %q", got)
}
if got := rec.Header().Get("Authorization"); got != "Bearer token-from-template" {
t.Errorf("Authorization: want mirrored (templated bearer), got %q", got)
}
if got := rec.Header().Get("Cookie"); got != "" {
t.Errorf("Cookie must NOT be mirrored, got %q", got)
}
}
func TestSuccessHandler_EmptyBody(t *testing.T) {
h := newSuccessHandler()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/x", nil)
h.ServeHTTP(rec, req)
if body := strings.TrimSpace(rec.Body.String()); body != "" {
t.Fatalf("body: want empty, got %q", body)
}
}
func TestSuccessHandler_MultiValueHeader(t *testing.T) {
h := newSuccessHandler()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/x", nil)
req.Header.Add("X-Role", "admin")
req.Header.Add("X-Role", "editor")
h.ServeHTTP(rec, req)
got := rec.Header()["X-Role"]
if len(got) != 2 || got[0] != "admin" || got[1] != "editor" {
t.Errorf("X-Role multi-value: want [admin editor], got %v", got)
}
}
File diff suppressed because it is too large Load Diff
-211
View File
@@ -1,211 +0,0 @@
// Package config provides configuration management for the OIDC middleware
package config
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
const (
minEncryptionKeyLength = 16
ConstSessionTimeout = 86400
)
//lint:ignore U1000 May be referenced for default exclusion patterns
var defaultExcludedURLs = map[string]struct{}{
"/favicon.ico": {},
"/robots.txt": {},
"/health": {},
"/.well-known/": {},
"/metrics": {},
"/ping": {},
"/api/": {},
"/static/": {},
"/assets/": {},
"/js/": {},
"/css/": {},
"/images/": {},
"/fonts/": {},
}
// Settings manages configuration and initialization for the OIDC middleware
type Settings struct {
logger Logger
}
// Logger interface for dependency injection
type Logger interface {
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
}
// Config represents the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerUrl"`
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackUrl"`
LogoutURL string `json:"logoutUrl"`
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHttps"`
LogLevel string `json:"logLevel"`
Scopes []string `json:"scopes"`
OverrideScopes bool `json:"overrideScopes"`
AllowedUsers []string `json:"allowedUsers"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedUrls"`
EnablePKCE bool `json:"enablePkce"`
RateLimit int `json:"rateLimit"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
Headers []HeaderConfig `json:"headers"`
HTTPClient *http.Client `json:"-"`
CookieDomain string `json:"cookieDomain"`
}
// HeaderConfig represents header template configuration
type HeaderConfig struct {
Name string `json:"name"`
Value string `json:"value"`
}
// NewSettings creates a new Settings instance
func NewSettings(logger Logger) *Settings {
return &Settings{
logger: logger,
}
}
// CreateConfig creates a default configuration
func CreateConfig() *Config {
return &Config{
LogLevel: "INFO",
ForceHTTPS: true,
EnablePKCE: true,
RateLimit: 10,
RefreshGracePeriodSeconds: 60,
Scopes: []string{"openid", "profile", "email"},
Headers: []HeaderConfig{},
}
}
// InitializeTraefikOidc would initialize and configure a new TraefikOidc instance
// This functionality has been moved to the main New function in main.go
// This function is kept for compatibility but should not be used
func (s *Settings) InitializeTraefikOidc(ctx context.Context, next http.Handler, config *Config, name string) (interface{}, error) {
return nil, fmt.Errorf("InitializeTraefikOidc is deprecated - use New function from main package instead")
}
//lint:ignore U1000 Kept for backward compatibility
func (s *Settings) setupHeaderTemplates(t interface{}, config *Config, logger Logger) error {
logger.Debug("setupHeaderTemplates is deprecated")
return nil
}
//lint:ignore U1000 May be needed for future background service management
func (s *Settings) startBackgroundServices(ctx context.Context, logger Logger) {
startReplayCacheCleanup(ctx, logger)
// Start memory monitoring for leak detection and performance insights
memoryMonitor := GetGlobalMemoryMonitor()
memoryMonitor.StartMonitoring(ctx, 60*time.Second) // Monitor every minute
logger.Debug("Started global memory monitoring")
}
// Utility functions
//lint:ignore U1000 May be needed for future scope processing
func deduplicateScopes(scopes []string) []string {
seen := make(map[string]bool)
result := []string{}
for _, scope := range scopes {
if !seen[scope] {
seen[scope] = true
result = append(result, scope)
}
}
return result
}
//lint:ignore U1000 May be needed for future scope merging operations
func mergeScopes(defaultScopes, userScopes []string) []string {
result := make([]string, len(defaultScopes))
copy(result, defaultScopes)
return append(result, userScopes...)
}
//lint:ignore U1000 May be needed for future utility operations
func createStringMap(items []string) map[string]struct{} {
result := make(map[string]struct{})
for _, item := range items {
result[item] = struct{}{}
}
return result
}
//lint:ignore U1000 May be needed for future case-insensitive operations
func createCaseInsensitiveStringMap(items []string) map[string]struct{} {
result := make(map[string]struct{})
for _, item := range items {
result[strings.ToLower(item)] = struct{}{}
}
return result
}
//lint:ignore U1000 May be needed for future test environment detection
func isTestMode() bool {
// This function should be implemented based on environment detection logic
return false
}
// External dependencies that need to be provided
// TraefikOidc struct is defined in types.go
// These functions need to be provided by external packages
func NewLogger(level string) Logger { return nil }
func CreateDefaultHTTPClient() *http.Client { return nil }
func CreateTokenHTTPClient() *http.Client { return nil }
func GetGlobalCacheManager(*sync.WaitGroup) CacheManager { return nil }
func NewSessionManager(string, bool, string, Logger) (SessionManager, error) { return nil, nil }
func NewErrorRecoveryManager(Logger) ErrorRecoveryManager { return nil }
//lint:ignore U1000 May be needed for future token claim extraction
func extractClaims(string) (map[string]interface{}, error) { return nil, nil }
//lint:ignore U1000 May be needed for future replay attack prevention
func startReplayCacheCleanup(context.Context, Logger) {}
func GetGlobalMemoryMonitor() MemoryMonitor { return nil }
// Interfaces for external dependencies
type CacheManager interface {
GetSharedTokenBlacklist() CacheInterface
GetSharedTokenCache() *TokenCache
GetSharedMetadataCache() *MetadataCache
GetSharedJWKCache() JWKCacheInterface
Close() error
}
type SessionManager interface{}
type ErrorRecoveryManager interface{}
type MemoryMonitor interface {
StartMonitoring(ctx context.Context, interval time.Duration)
}
type CacheInterface interface {
Set(key string, value interface{}, ttl time.Duration)
Get(key string) (interface{}, bool)
Delete(key string)
SetMaxSize(size int)
Cleanup()
Close()
}
type TokenCache struct{}
type MetadataCache struct{}
type JWKCacheInterface interface{}
+116
View File
@@ -0,0 +1,116 @@
package traefikoidc
import (
"encoding/json"
)
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return result, nil
}
// MarshalJSON for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalJSON() ([]byte, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return json.Marshal(result)
}
// MarshalYAML for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalYAML() (interface{}, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return result, nil
}
File diff suppressed because it is too large Load Diff
+12 -12
View File
@@ -18,7 +18,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test that CSRF tokens persist through the authentication flow
t.Run("CSRF_Token_Persists_After_Selective_Clear", func(t *testing.T) {
// Create a session manager
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Create initial request
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
session.SetCSRF(csrfToken)
session.SetNonce("test-nonce")
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("old-refresh-token")
session.SetIDToken("old-id-token")
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Now perform selective clearing (as done in the fix)
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
@@ -90,7 +90,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test that marking session as dirty forces save
t.Run("Mark_Dirty_Forces_Session_Save", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
@@ -126,7 +126,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test Azure-specific session handling
t.Run("Azure_Session_Cookie_Configuration", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Simulate Azure callback scenario
@@ -158,7 +158,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test session continuity through auth flow
t.Run("Session_Continuity_Through_Auth_Flow", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Step 1: Initial request
@@ -199,7 +199,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Test large token handling doesn't affect CSRF
t.Run("Large_Tokens_Dont_Affect_CSRF", func(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
req := httptest.NewRequest("GET", "http://example.com/test", nil)
@@ -262,7 +262,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
// We can't fully initialize TraefikOidc without network access,
// but we can test the session management directly
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", NewLogger(plugin.LogLevel))
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", "", 0, NewLogger(plugin.LogLevel))
require.NoError(t, err)
t.Run("Session_Created_On_Protected_Request", func(t *testing.T) {
@@ -291,7 +291,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
// TestRegressionLoginLoop specifically tests the fix for issue #53
func TestRegressionLoginLoop(t *testing.T) {
// This test verifies that the specific changes made to fix the login loop work correctly
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
// Simulate the exact flow that was causing the login loop
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// Set initial session data
session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-token")
session.SetCSRF("existing-csrf")
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
// NEW BEHAVIOR: Selective clearing
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
@@ -392,7 +392,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// TestCSRFValidationTiming tests timing-sensitive CSRF validation scenarios
func TestCSRFValidationTiming(t *testing.T) {
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
require.NoError(t, err)
t.Run("Rapid_Redirect_Maintains_CSRF", func(t *testing.T) {
+364
View File
@@ -0,0 +1,364 @@
//go:build !yaegi
package traefikoidc
import (
"testing"
)
// TestCustomClaimNames_DefaultBehavior tests backward compatibility with default claim names
func TestCustomClaimNames_DefaultBehavior(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Explicitly set defaults to test backward compatibility
ts.tOidc.roleClaimName = "roles"
ts.tOidc.groupClaimName = "groups"
// Test that when no custom claim names are configured, it uses defaults "roles" and "groups"
claims := map[string]interface{}{
"groups": []interface{}{"admin", "users"},
"roles": []interface{}{"editor", "viewer"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin", "users"}) {
t.Errorf("Expected groups [admin users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
t.Errorf("Expected roles [editor viewer], got %v", roles)
}
}
// TestCustomClaimNames_Auth0Namespaced tests Auth0-style namespaced claims
func TestCustomClaimNames_Auth0Namespaced(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names for Auth0
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with Auth0-style namespaced claims
claims := map[string]interface{}{
"https://myapp.com/groups": []interface{}{"admin", "users"},
"https://myapp.com/roles": []interface{}{"editor", "viewer"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin", "users"}) {
t.Errorf("Expected groups [admin users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
t.Errorf("Expected roles [editor viewer], got %v", roles)
}
}
// TestCustomClaimNames_CustomSimpleNames tests custom simple claim names
func TestCustomClaimNames_CustomSimpleNames(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom simple claim names
ts.tOidc.roleClaimName = "user_roles"
ts.tOidc.groupClaimName = "user_groups"
// Create token with custom claim names
claims := map[string]interface{}{
"user_groups": []interface{}{"engineering", "product"},
"user_roles": []interface{}{"developer", "manager"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"engineering", "product"}) {
t.Errorf("Expected groups [engineering product], got %v", groups)
}
if !stringSliceEqual(roles, []string{"developer", "manager"}) {
t.Errorf("Expected roles [developer manager], got %v", roles)
}
}
// TestCustomClaimNames_MissingClaims tests behavior when custom claims are missing
func TestCustomClaimNames_MissingClaims(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
ts.tOidc.groupClaimName = "custom_groups"
// Create token WITHOUT the custom claims
claims := map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should return empty slices, not error
if len(groups) != 0 {
t.Errorf("Expected empty groups, got %v", groups)
}
if len(roles) != 0 {
t.Errorf("Expected empty roles, got %v", roles)
}
}
// TestCustomClaimNames_MalformedClaims tests error handling for malformed claims
func TestCustomClaimNames_MalformedRoleClaim(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
// Create token with malformed role claim (not an array)
claims := map[string]interface{}{
"custom_roles": "this-should-be-an-array",
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
if err == nil {
t.Error("Expected error for malformed role claim, got nil")
}
// Check error message contains the custom claim name
expectedError := "custom_roles claim is not an array"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
// TestCustomClaimNames_MalformedGroupClaim tests error handling for malformed group claims
func TestCustomClaimNames_MalformedGroupClaim(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.groupClaimName = "custom_groups"
// Create token with malformed group claim (not an array)
claims := map[string]interface{}{
"custom_groups": 12345, // Not an array
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
if err == nil {
t.Error("Expected error for malformed group claim, got nil")
}
// Check error message contains the custom claim name
expectedError := "custom_groups claim is not an array"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
// TestCustomClaimNames_PartialConfiguration tests when only one claim name is customized
func TestCustomClaimNames_OnlyRoleCustomized(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure only role claim name (group uses default)
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "groups" // default
// Create token with mixed claim names
claims := map[string]interface{}{
"groups": []interface{}{"admin"},
"https://myapp.com/roles": []interface{}{"editor"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"admin"}) {
t.Errorf("Expected groups [admin], got %v", groups)
}
if !stringSliceEqual(roles, []string{"editor"}) {
t.Errorf("Expected roles [editor], got %v", roles)
}
}
// TestCustomClaimNames_OnlyGroupCustomized tests when only group claim name is customized
func TestCustomClaimNames_OnlyGroupCustomized(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure only group claim name (role uses default)
ts.tOidc.roleClaimName = "roles" // default
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with mixed claim names
claims := map[string]interface{}{
"roles": []interface{}{"viewer"},
"https://myapp.com/groups": []interface{}{"users"},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !stringSliceEqual(groups, []string{"users"}) {
t.Errorf("Expected groups [users], got %v", groups)
}
if !stringSliceEqual(roles, []string{"viewer"}) {
t.Errorf("Expected roles [viewer], got %v", roles)
}
}
// TestCustomClaimNames_EmptyArrays tests extraction with empty claim arrays
func TestCustomClaimNames_EmptyArrays(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "https://myapp.com/roles"
ts.tOidc.groupClaimName = "https://myapp.com/groups"
// Create token with empty arrays
claims := map[string]interface{}{
"https://myapp.com/groups": []interface{}{},
"https://myapp.com/roles": []interface{}{},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(groups) != 0 {
t.Errorf("Expected empty groups, got %v", groups)
}
if len(roles) != 0 {
t.Errorf("Expected empty roles, got %v", roles)
}
}
// TestCustomClaimNames_NonStringElements tests handling of non-string elements in claim arrays
func TestCustomClaimNames_NonStringInRoleArray(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.roleClaimName = "custom_roles"
// Create token with mixed-type array (should skip non-string elements)
claims := map[string]interface{}{
"custom_roles": []interface{}{"role1", 12345, "role2", true},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
_, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should only extract string elements
if !stringSliceEqual(roles, []string{"role1", "role2"}) {
t.Errorf("Expected roles [role1 role2], got %v", roles)
}
}
// TestCustomClaimNames_NonStringInGroupArray tests handling of non-string elements in group arrays
func TestCustomClaimNames_NonStringInGroupArray(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Configure custom claim names
ts.tOidc.groupClaimName = "custom_groups"
// Create token with mixed-type array (should skip non-string elements)
claims := map[string]interface{}{
"custom_groups": []interface{}{"group1", nil, "group2", 3.14},
}
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, _, err := ts.tOidc.extractGroupsAndRoles(token)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Should only extract string elements
if !stringSliceEqual(groups, []string{"group1", "group2"}) {
t.Errorf("Expected groups [group1 group2], got %v", groups)
}
}
+290
View File
@@ -0,0 +1,290 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
)
// DCRStorageBackend represents the type of storage backend for DCR credentials.
// Alias for internal package type for backward compatibility.
type DCRStorageBackend = dcrstorage.StorageBackend
const (
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
// DCRStorageBackendRedis uses Redis for distributed storage
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
)
// DCRCredentialsStore defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type DCRCredentialsStore interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
type loggerAdapter struct {
logger *Logger
}
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
type cacheAdapter struct {
cache *UniversalCache
}
func (c *cacheAdapter) Get(key string) (any, bool) {
return c.cache.Get(key)
}
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
return c.cache.Set(key, value, ttl)
}
func (c *cacheAdapter) Delete(key string) {
c.cache.Delete(key)
}
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
type fileStoreWrapper struct {
inner *dcrstorage.FileStore
}
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
func (w *fileStoreWrapper) basePath() string {
return w.inner.BasePath()
}
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
return w.inner.GetFilePath(providerURL)
}
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
type redisStoreWrapper struct {
inner *dcrstorage.RedisStore
}
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
type FileCredentialsStore = fileStoreWrapper
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
type RedisCredentialsStore = redisStoreWrapper
// NewFileCredentialsStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
return &fileStoreWrapper{inner: inner}
}
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
cacheAdapt := &cacheAdapter{cache: cache}
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
return &redisStoreWrapper{inner: inner}
}
// Helper functions to convert between main package and internal package types
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
if creds == nil {
return nil
}
return &dcrstorage.ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
if creds == nil {
return nil
}
return &ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
// This factory function handles backend selection logic:
// - "file": Use file-based storage (default for backward compatibility)
// - "redis": Use Redis exclusively (fails if Redis unavailable)
// - "auto": Use Redis if available, fallback to file
func NewDCRCredentialsStore(
config *DynamicClientRegistrationConfig,
cacheManager *CacheManager,
logger *Logger,
) (DCRCredentialsStore, error) {
if config == nil {
return nil, fmt.Errorf("DCR config is nil")
}
if logger == nil {
logger = GetSingletonNoOpLogger()
}
backend := config.StorageBackend
if backend == "" {
backend = string(DCRStorageBackendAuto) // Default to auto selection
}
switch DCRStorageBackend(backend) {
case DCRStorageBackendFile:
logger.Info("Using file-based storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
case DCRStorageBackendRedis:
cache := getDCRCache(cacheManager)
if cache == nil {
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
}
logger.Info("Using Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
case DCRStorageBackendAuto:
// Try Redis first, fallback to file
cache := getDCRCache(cacheManager)
if cache != nil && cache.backend != nil {
logger.Info("Auto-selected Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
}
logger.Info("Redis not available, using file storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
default:
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
}
}
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
if cacheManager == nil {
return nil
}
cacheManager.mu.RLock()
defer cacheManager.mu.RUnlock()
if cacheManager.manager == nil {
return nil
}
return cacheManager.manager.GetDCRCredentialsCache()
}
+663
View File
@@ -0,0 +1,663 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
t.Parallel()
// Create a temp directory for test files
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
// Save credentials
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
// Load credentials
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
// Verify fields
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
// Should not error
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
// TestFileCredentialsStore_MultiProvider tests multi-provider support
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
// Save credentials for both providers
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
// Load and verify each provider's credentials
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
// Delete one shouldn't affect the other
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
// Concurrent saves
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
// Concurrent loads
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
// Final verification
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
// TestFileCredentialsStore_InvalidInput tests error handling
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
// TestFileCredentialsStore_DefaultPath tests default path behavior
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore("", logger)
// Just verify we can create with empty path and it has a default
if store.basePath() == "" {
t.Error("Expected default base path")
}
}
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
t.Parallel()
// Create an in-memory cache for testing
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0, // No expiry
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
// TestRedisCredentialsStore_InvalidInput tests error handling
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
// TestDCRStorageFactory tests the factory function
func TestDCRStorageFactory(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
t.Run("nil config returns error", func(t *testing.T) {
_, err := NewDCRCredentialsStore(nil, nil, logger)
if err == nil {
t.Error("Expected error for nil config")
}
})
t.Run("file backend creates file store", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "file",
CredentialsFile: "/tmp/test-creds.json",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create file store: %v", err)
}
if store == nil {
t.Error("Expected store but got nil")
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore")
}
})
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "redis",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for redis backend without cache manager")
}
})
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "auto",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create auto store: %v", err)
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for auto without redis")
}
})
t.Run("unknown backend returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "unknown",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for unknown backend")
}
})
t.Run("empty backend defaults to auto", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create store with empty backend: %v", err)
}
// Should default to file (auto without redis)
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for empty backend")
}
})
}
// TestDynamicClientRegistrar_WithStore tests registrar with store
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil, // httpClient
logger,
config,
"https://auth.example.com",
store,
)
if registrar == nil {
t.Fatal("Expected registrar but got nil")
}
if registrar.store == nil {
t.Error("Expected store to be set")
}
// Test SetStore
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
registrar.SetStore(newStore)
if registrar.store != newStore {
t.Error("SetStore did not update the store")
}
}
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
providerURL := "https://auth.example.com"
ctx := context.Background()
// Pre-save credentials
testCreds := &ClientRegistrationResponse{
ClientID: "pre-saved-client",
ClientSecret: "pre-saved-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
}
if err := store.Save(ctx, providerURL, testCreds); err != nil {
t.Fatalf("Failed to pre-save credentials: %v", err)
}
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil,
logger,
config,
providerURL,
store,
)
// Test loading via the internal method
loaded, err := registrar.loadCredentialsFromStore(ctx)
if err != nil {
t.Fatalf("Failed to load from store: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != "pre-saved-client" {
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
}
}
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
// Write corrupted JSON
filePath := store.getFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
// Should return error for corrupted file
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(deepPath, logger)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+427
View File
@@ -0,0 +1,427 @@
# Auth0 Audience Validation Guide
## Overview
This guide explains how to configure audience validation for Auth0 and other OIDC providers that support custom API audiences. It covers three common Auth0 scenarios and how to configure the middleware for maximum security.
## Table of Contents
1. [Understanding Audiences](#understanding-audiences)
2. [The Three Auth0 Scenarios](#the-three-auth0-scenarios)
3. [Configuration Options](#configuration-options)
4. [Security Recommendations](#security-recommendations)
5. [Troubleshooting](#troubleshooting)
---
## Understanding Audiences
### What is an Audience?
The **audience** (`aud`) claim in a JWT identifies the intended recipient of the token. Per OAuth 2.0 and OIDC specifications:
- **ID Tokens**: MUST have `aud = client_id` (per OIDC Core 1.0 spec)
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
### Why Does This Matter?
Audience validation rejects access tokens whose `aud` claim does not match the
expected audience, blocking the trivial form of token confusion where a token
issued for API A is presented to API B. (Defence in depth — pair with
short-lived tokens, rotation, and per-API client credentials.)
---
## The Three Auth0 Scenarios
### Scenario 1: Custom API Audience ✅ **RECOMMENDED**
**Configuration:**
```yaml
audience: "https://my-api.example.com" # Your API identifier from Auth0
```
**What Happens:**
1. Authorization request includes `audience` parameter
2. Auth0 issues:
- **ID Token**: `aud = client_id`
- **Access Token**: `aud = ["https://issuer/userinfo", "https://my-api.example.com"]`
3. Middleware validates:
- ID tokens against `client_id`
- Access tokens against custom audience
**Result:** ✅ Fully secure, OIDC compliant
---
### Scenario 2: Default Audience (No Custom API) ⚠️ **USE WITH CAUTION**
**Configuration:**
```yaml
# audience not specified (defaults to client_id)
```
**What Happens:**
1. Authorization request WITHOUT `audience` parameter
2. Auth0 issues:
- **ID Token**: `aud = client_id`
- **Access Token**: `aud = ["https://issuer/userinfo", "default_api"]` (no `client_id`)
3. Access token validation fails (audience mismatch)
4. Middleware falls back to ID token validation
**Security Warning:**
```
⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!
⚠️ This could allow tokens intended for different APIs to grant access
⚠️ Set strictAudienceValidation=true to enforce proper audience validation
⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74
```
**Recommended Fix:**
```yaml
strictAudienceValidation: true # Reject sessions with audience mismatch
```
**Result:**
- Default: ⚠️ Works but logs security warnings
- With strict mode: ✅ Secure (rejects mismatched tokens)
---
### Scenario 3: Opaque Access Tokens ✅ **SUPPORTED**
**Configuration:**
```yaml
allowOpaqueTokens: true # Enable opaque token support
requireTokenIntrospection: true # Require introspection (recommended)
```
**What Happens:**
1. Auth0 issues opaque (non-JWT) access token
2. Middleware detects opaque token (not 3 parts separated by dots)
3. Uses OAuth 2.0 Token Introspection (RFC 7662) to validate
4. Falls back to ID token if introspection unavailable (unless `requireTokenIntrospection=true`)
**Requirements:**
- Provider must support `introspection_endpoint` in OIDC discovery
- Client must have introspection permissions
**Result:** ✅ Secure with introspection, ⚠️ risky without
---
## Configuration Options
### Audience Settings
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `audience` | string | `client_id` | Expected audience for access tokens |
**Example:**
```yaml
# .traefik.yml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
audience: "https://my-api.example.com"
```
---
### Security Mode Settings
#### `strictAudienceValidation`
**Type:** boolean
**Default:** `false`
**Recommended:** `true` for production
**What it does:**
- When `true`: On audience mismatch, the middleware does **not** silently fall back to ID-token validation. It tries to refresh the access token first; if no refresh token is present (or refresh fails), the user is re-authenticated.
- When `false`: Logs warnings and falls back to ID-token validation (backward compatible).
**Example:**
```yaml
strictAudienceValidation: true
```
**When to use:**
- ✅ Always use in production environments
- ✅ When you have custom API audiences configured in Auth0
- ⚠️ May break existing deployments relying on Scenario 2 behavior
---
#### `allowOpaqueTokens`
**Type:** boolean
**Default:** `false`
**What it does:**
- When `true`: Accepts opaque (non-JWT) access tokens
- When `false`: Only accepts JWT access tokens
**Example:**
```yaml
allowOpaqueTokens: true
```
**When to use:**
- ✅ When Auth0 issues opaque tokens (no default API configured)
- ✅ When using Auth0 Management API tokens
- ⚠️ Requires introspection endpoint for security
---
#### `requireTokenIntrospection`
**Type:** boolean
**Default:** `false`
**Recommended:** `true` when `allowOpaqueTokens=true`
**What it does:**
- When `true`: Rejects opaque tokens if introspection fails or endpoint unavailable
- When `false`: Falls back to ID token validation for opaque tokens
**Example:**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
**When to use:**
- ✅ Always use when `allowOpaqueTokens=true` for maximum security
- ⚠️ Requires provider to expose introspection endpoint
---
## Security Recommendations
### Recommended Configuration for Auth0
**For APIs with custom audiences (Scenario 1):**
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
allowOpaqueTokens: false
```
**For default Auth0 setup (Scenario 2):**
```yaml
# Don't set audience (defaults to client_id)
strictAudienceValidation: true # Enforce proper configuration
```
**For opaque tokens (Scenario 3):**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Security Best Practices
1. ✅ **Always set `strictAudienceValidation: true` in production**
2. ✅ **Configure custom API audiences in Auth0 dashboard**
3. ✅ **Use `requireTokenIntrospection: true` if accepting opaque tokens**
4. ✅ **Monitor logs for security warnings**
5. ❌ **Don't rely on Scenario 2 fallback behavior**
---
## Troubleshooting
### "Access token validation failed due to audience mismatch"
**Symptom:**
```
⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch
```
**Cause:** Access token audience doesn't match configured audience
**Solutions:**
1. **Configure correct audience:**
```yaml
audience: "https://your-api-identifier" # From Auth0 API settings
```
2. **Update Auth0 authorization request:**
- Ensure `audience` parameter is included in authorize URL
- Middleware automatically adds this when `audience != client_id`
3. **Accept the behavior (not recommended):**
```yaml
strictAudienceValidation: false # Logs warnings but allows
```
---
### "Opaque token detected but allowOpaqueTokens=false"
**Symptom:**
```
⚠️ Opaque access token detected but allowOpaqueTokens=false
```
**Cause:** Auth0 issued non-JWT access token but middleware not configured to accept them
**Solutions:**
1. **Enable opaque tokens:**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
2. **Configure Auth0 to issue JWT access tokens:**
- Create an API in Auth0 dashboard
- Set API identifier as `audience` in configuration
---
### "Introspection endpoint not available"
**Symptom:**
```
⚠️ Opaque tokens enabled but no introspection endpoint available from provider
```
**Cause:** Auth0 provider metadata doesn't include `introspection_endpoint`
**Solutions:**
1. **Check provider discovery:**
```bash
curl https://YOUR_DOMAIN/.well-known/openid-configuration
```
Look for `introspection_endpoint`
2. **Disable required introspection (less secure):**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: false # Falls back to ID token
```
3. **Use JWT access tokens instead** (recommended)
---
### "Token introspection required but endpoint not available"
**Symptom:**
```
❌ SECURITY: Opaque token rejected (introspection required but failed)
```
**Cause:** `requireTokenIntrospection=true` but provider doesn't support it
**Solutions:**
1. **Disable required introspection:**
```yaml
requireTokenIntrospection: false
```
2. **Configure Auth0 to issue JWT tokens** (better solution)
---
## Advanced Topics
### Token Type Detection
The middleware uses a sophisticated 6-step detection algorithm:
1. **RFC 9068 `typ` header**: `at+jwt` → Access Token
2. **Explicit type claims**: `token_use`, `token_type`
3. **`scope` claim**: Present → Access Token
4. **`nonce` claim**: Present → ID Token (OIDC spec)
5. **Audience check**: `aud == client_id` only → ID Token
6. **Default**: Access Token
### OAuth 2.0 Token Introspection (RFC 7662)
When opaque tokens are detected:
1. Middleware calls provider's `introspection_endpoint`
2. Authenticates using client credentials
3. Receives response with `active` status and claims
4. Caches result for 5 minutes (configurable via TTL)
5. Validates expiration, not-before, and audience if present
**Cache behavior:**
- Cache key: Token hash
- TTL: 5 minutes; if the token's `exp` is sooner, the cache entry expires at `exp` instead. Tokens without `exp` use the flat 5-minute TTL.
- Reduces introspection requests for frequently used tokens
---
## Reference Links
- [GitHub Issue #74](https://github.com/lukaszraczylo/traefikoidc/issues/74) - Original Auth0 audience discussion
- [OIDC Core 1.0 Spec](https://openid.net/specs/openid-connect-core-1_0.html) - ID Token requirements
- [OAuth 2.0 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) - OAuth 2.0 specification
- [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662) - OAuth 2.0 Token Introspection
- [RFC 9068](https://datatracker.ietf.org/doc/html/rfc9068) - JWT Access Token Profile
- [Auth0 API Authorization](https://auth0.com/docs/secure/tokens/access-tokens) - Auth0 audience documentation
---
## Migration Guide
### From Previous Versions
**If you're upgrading from a version without these features:**
1. **No action required for default behavior** - backward compatible
2. **Recommended: Enable strict mode gradually**
```yaml
# Step 1: Enable and monitor logs
strictAudienceValidation: false # Default
# Step 2: After confirming no warnings, enable
strictAudienceValidation: true
```
3. **For opaque tokens: Enable explicitly**
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
```
### Testing Your Configuration
1. **Check logs for warnings:**
```bash
# Look for Scenario 2 warnings
grep "SCENARIO 2 DETECTED" /var/log/traefik.log
# Look for opaque token warnings
grep "Opaque" /var/log/traefik.log
```
2. **Test with curl:**
```bash
# Get token from Auth0
ACCESS_TOKEN="your_access_token"
# Test request
curl -H "Authorization: Bearer $ACCESS_TOKEN" \
https://your-app.example.com/api
```
3. **Monitor for security warnings in production logs**
---
## Support
For issues or questions:
- GitHub Issues: https://github.com/lukaszraczylo/traefikoidc/issues
- Security issues: See SECURITY.md for responsible disclosure
---
**Last Updated:** 2025-01-09
**Version:** 0.7.8+
+250
View File
@@ -0,0 +1,250 @@
# Bearer Token (M2M) Authentication
Opt-in path that lets API clients present `Authorization: Bearer <jwt>` to
authenticate without going through the cookie-based OIDC redirect flow.
Designed for machine-to-machine (M2M) traffic — services calling other
services with tokens minted by your OIDC provider.
The bearer path lives next to the cookie path: both go through the same
post-auth pipeline (`forwardAuthorized`) that injects identity headers,
checks `allowedRolesAndGroups`, applies security headers, and forwards to
the backend. The only thing that differs is how the principal is established
for that single request.
## Quick start
```yaml
enableBearerAuth: true
audience: https://api.example.com # REQUIRED when bearer is enabled
clientID: my-api-client-id
providerURL: https://issuer.example.com
sessionEncryptionKey: <32+-byte secret>
callbackURL: /oauth2/callback
```
That is the minimum. Everything else has a secure default.
## Obtaining bearer tokens from your OIDC provider
The middleware only **validates** bearer tokens — minting them is the IdP's job. For M2M traffic the canonical mint flow is OAuth 2.0 **`client_credentials`** (RFC 6749 §4.4); some providers require **JWT bearer assertion** (RFC 7523) instead.
```
┌────────────┐ POST /token ┌──────────┐
│ client │ ───────────────────────────────►│ IdP │
│ (service) │ grant_type=client_credentials │ /token │
│ │ client_id=… │ │
│ │ client_secret=… (or JWT) │ │
│ │ audience=https://api.… ←── critical │
│ │ scope=api:read … │
│ │ ◄───────────────────────────────│ │
│ │ access_token (JWT) │ │
└────────────┘ └──────────┘
│ GET /protected
│ Authorization: Bearer <access_token>
Your service (behind Traefik + this plugin)
```
The IdP returns a JWT signed by the same JWKs the middleware already trusts (it discovers them from `providerURL`/.well-known). On the first protected request, the middleware verifies signature + issuer + **audience** + `exp` + identifier claim, then forwards downstream with `X-Forwarded-User` set.
### Minimal worked example (Auth0-shape)
```bash
# 1. Mint a token
curl -s -X POST https://issuer.example.com/oauth/token \
-H 'Content-Type: application/json' \
-d '{
"grant_type": "client_credentials",
"client_id": "your-m2m-client-id",
"client_secret": "your-m2m-client-secret",
"audience": "https://api.example.com",
"scope": "api:read api:write"
}'
# → {"access_token":"eyJhbGciOiJSUzI1NiIs…","token_type":"Bearer","expires_in":86400,…}
# 2. Use it
curl -H 'Authorization: Bearer eyJhbGciOiJSUzI1NiIs…' https://api.example.com/protected
```
The `audience` field in the token request **must match** the `audience` you configured on the middleware. Mismatch → 401 with `Bearer error="invalid_token"`.
### Per-provider quick reference
| Provider | Grant | Token endpoint | Audience parameter | Notes |
|---|---|---|---|---|
| **Auth0** | `client_credentials` | `https://TENANT.auth0.com/oauth/token` | `audience=<your API identifier>` | Register an "API" + "Machine to Machine Application" authorised against that API. Without `audience` you get an opaque /userinfo token, which the bearer path rejects. See `docs/AUTH0_AUDIENCE_GUIDE.md`. |
| **Okta** | `client_credentials` | `https://TENANT.okta.com/oauth2/default/v1/token` | Configured in the authorization server; default `aud` is the auth-server URL | Service app must enable the `client_credentials` flow and be granted the requested scopes. |
| **Keycloak** | `client_credentials` | `https://kc/realms/REALM/protocol/openid-connect/token` | Configure an "Audience" mapper on a client scope, or use `client_id` as the audience | Client must have `serviceAccountsEnabled: true` plus role mappings. |
| **Entra ID / Azure AD** | `client_credentials` (v2.0 endpoint) | `https://login.microsoftonline.com/TENANT/oauth2/v2.0/token` | Pass `scope=<App ID URI>/.default`; `aud` ends up being the API's App ID URI | Requires an App Registration + API permissions + admin consent. **Use the v2.0 endpoint** — v1 issues Microsoft-proprietary access tokens that are opaque to non-Microsoft clients. |
| **AWS Cognito** | `client_credentials` | `https://YOUR_DOMAIN.auth.REGION.amazoncognito.com/oauth2/token` | Scopes from a "Resource Server" attached to your User Pool | App client must have `client_credentials` flow enabled. Use HTTP **Basic** auth header for `client_id:client_secret`. |
| **GitLab** | `client_credentials` | `https://gitlab.com/oauth/token` | Audience matches the GitLab issuer | Rarely used for protecting external APIs; better suited for GitLab's own resources. |
| **Google** | **JWT bearer (RFC 7523)***not* `client_credentials` | `https://oauth2.googleapis.com/token` | Signed assertion JWT carries `aud=https://oauth2.googleapis.com/token`; resulting access token is **opaque** unless you specifically request a Google-issued JWT for your API | Google service-account flow is not the best fit for this middleware (opaque tokens are rejected on the bearer path). Run Auth0 / Okta / Keycloak in front, or use ID-token-based flows on the cookie path. |
### RFC 7523 (JWT bearer assertion) — secretless alternative
When shared secrets are forbidden (FAPI, internal compliance), swap `client_secret` for a signed JWT assertion:
```
POST /token
grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer
assertion=<JWT signed by the client's private key>
```
The assertion JWT carries `iss=<client_id>`, `sub=<client_id>`, `aud=<token endpoint>`, `exp`. The IdP verifies the signature against a public key you've pre-registered and returns an access token.
This middleware already supports JWT assertions on the *middleware → IdP* hop via `clientAuthMethod: private_key_jwt` (see `docs/CONFIGURATION.md`). For the *client → IdP* hop, the same pattern applies — the client signs its own assertion.
### Operational notes
- **Token TTL is typically 124 hours.** Clients should refresh on `401`, not on a polling timer — saves the IdP.
- **Cache and reuse tokens.** The middleware caches verified tokens too, so repeated presentations are cheap. Clients SHOULD reuse a token until ~80 % of `expires_in`.
- **JWKS rotation is transparent.** The middleware auto-refreshes its JWKS cache when the IdP rotates keys. Clients don't need to do anything.
- **Revocation is generally not per-token** with `client_credentials`. If you need real-time revocation, set `requireTokenIntrospection: true` on the middleware and the IdP is consulted on every cache miss.
- **`scope` vs `audience`.** Scope says *what the client may do*; audience says *which service the token is for*. The middleware enforces audience; the backend service should enforce scope.
- **Secret hygiene.** Store `client_secret` in a secrets manager (Vault, AWS Secrets Manager, Kubernetes `Secret`). For higher assurance, switch the client to `private_key_jwt` (no shared secret at all).
### Quickest validation loop
```bash
# 1. Mint
TOKEN=$(curl -s -X POST https://issuer.example.com/oauth/token \
-H 'Content-Type: application/json' \
-d '{"grant_type":"client_credentials","client_id":"…","client_secret":"…","audience":"https://api.example.com"}' \
| jq -r .access_token)
# 2. Inspect claims to confirm aud/iss/exp match the middleware config
echo "$TOKEN" | cut -d. -f2 | base64 -d 2>/dev/null | jq
# 3. Hit the protected route
curl -i -H "Authorization: Bearer $TOKEN" https://api.example.com/protected
```
`HTTP/1.1 200` with `X-Forwarded-User` on the backend confirms the loop works end-to-end. `401` with `WWW-Authenticate: Bearer error="invalid_token"` plus a middleware debug log explaining the rejection (audience mismatch, ID token presented, `iat` outside the 24h window, etc.) confirms the hardening is firing as designed.
## Threat model and design rules
Bearer authentication has materially different security properties from
cookie sessions: no `HttpOnly`/`Secure`/`SameSite` shielding, the token is
visible in headers and logs, and it's easier to exfiltrate. The bearer path
treats every one of these as a first-class concern.
| Property | Behaviour | Why |
|---|---|---|
| Default state | `enableBearerAuth=false` | Bearer is opt-in; existing deployments observe no change. |
| Audience | **Mandatory.** Startup fails if `audience` is empty when bearer is enabled. | Eliminates the "token issued for service B accepted by service A" confusion attack. |
| Token format | JWT only (3 segments, JOSE-encoded). Opaque tokens are not accepted on the bearer path. | Matches the validation pipeline; opaque tokens require introspection only and bypass JWT-specific defences. |
| `alg` allowlist | Hard-pinned asymmetric: `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. Checked **before** any JWKS fetch. | Denies `alg=none` and `alg=HS*` probes; prevents attacker noise from amplifying into JWKS round-trips. |
| `kid` hardening | Max 256 bytes; charset `[A-Za-z0-9._\-=]`. Checked **before** JWKS fetch. | Prevents cache-key explosion / pathological-`kid` JWKS amplification. |
| Token type | ID tokens are explicitly rejected (`nonce` claim, `typ: at+jwt`, `token_use=id`, scope/aud heuristics — reuses the existing `detectTokenType` helper). | ID tokens are not API credentials; treating them as such is classic token confusion. |
| Multi-audience | When `aud` is an array of length > 1, the token must carry `azp == clientID`. | OIDC §2 hardening against tokens minted for one client being replayed by another. |
| `iat` upper-age | Rejects tokens older than `maxTokenAgeSeconds` (default 24h). | Bounds clock-manipulation / forever-token abuse, even if `exp` is far in the future. |
| Identifier claim | `bearerIdentifierClaim` (default `"sub"`). Resolved value drives `X-Forwarded-User`. | Decoupled from the cookie path's `UserIdentifierClaim` (default `email`) so the M2M flow can never accidentally trust an unverified email. |
| Identifier sanitisation | Length cap (`maxIdentifierLength`, default 256). Rejects control chars, Unicode bidi-overrides (U+202AU+202E, U+2066U+2069), and the delimiters `, ; =`. | Defence in depth against downstream header injection / log injection / admin-UI spoofing. |
| JTI replay marking | Bearer path skips the JTI **Set** (so the same token can be reused until `exp`) but the **Get** stays active. | Allows legitimate bearer reuse without false-positive replay detection; revoked tokens (added to the blacklist by `RevokeToken`) still fail immediately. |
| Mixed bearer + cookie | **Cookie wins by default.** Flip to bearer-wins with `bearerOverridesCookie=true`. | Safer against browser/extension/proxy bearer injection scenarios. The cookie is the authoritative authenticator when present. |
| `Authorization` strip | `stripAuthorizationHeader=true` by default. | Keeps the raw token out of downstream services and their logs. |
| Excluded URLs | `Authorization` is stripped on excluded paths when `enableBearerAuth=true`. | Prevents bearer leakage into public health/metrics endpoint logs and prevents recon via excluded paths. |
| Per-IP throttle | After `bearerFailureThreshold` consecutive 401s from one source IP within `bearerFailureWindowSeconds`, further bearer requests from that IP return `429 Too Many Requests` + `Retry-After` for `bearerFailurePenaltySeconds`. | Limits offline-guessing-style attacks and protects the shared rate-limiter / JWKS endpoint. |
| Optional introspection | `requireTokenIntrospection=true` calls RFC 7662 introspection on every cache miss. Introspection result is cached briefly. Endpoint failure returns `503` (distinguishes infra outage from credential rejection). | Real-time revocation for high-assurance environments. Adds per-request IdP latency. |
| Response shape | `401 Unauthorized` with generic body. `WWW-Authenticate: Bearer error="invalid_token"` per RFC 6750 §3 (toggleable via `bearerEmitWWWAuthenticate`). `403` for roles/groups denial. `429` for throttle. `503` for introspection-endpoint outage. | Auditable from spec to code; reason categories never leak into the response body. |
| Logging | Failure reason + identifier hash (SHA-256 truncated to 8 hex chars) logged at debug. Raw tokens are never logged. | Audit trail without secrets-in-logs. |
## Configuration reference
| Field | Default | Description |
|---|---|---|
| `enableBearerAuth` | `false` | Master switch for the bearer path. |
| `audience` | (unset) | **Required** when `enableBearerAuth=true`. Reuses the existing global `audience` field. |
| `bearerIdentifierClaim` | `"sub"` | JWT claim used as the principal identifier. `"email"` is rejected at startup. |
| `stripAuthorizationHeader` | `true` | Remove the `Authorization` header before forwarding to the backend. Disable only when a downstream needs to re-verify the bearer. |
| `bearerEmitWWWAuthenticate` | `true` | Include `WWW-Authenticate: Bearer error="..."` on 401 responses (RFC 6750 §3). Disable to reduce recon signal. |
| `bearerOverridesCookie` | `false` | Cookie wins when both are present (default). Set `true` for the AWS/GCP/Kubernetes bearer-wins convention. |
| `maxTokenAgeSeconds` | `86400` | Upper bound on `iat` claim age (24h). Set `0` to disable the check (not recommended). |
| `maxIdentifierLength` | `256` | Length cap for the post-sanitisation identifier. |
| `bearerFailureThreshold` | `20` | Consecutive 401s from one IP that trip the throttle. |
| `bearerFailureWindowSeconds` | `60` | Rolling window over which 401s are counted. |
| `bearerFailurePenaltySeconds` | `60` | Duration of the 429 penalty box after the threshold trips. |
| `requireTokenIntrospection` | `false` | Call RFC 7662 introspection on every cache miss. Adds per-request IdP latency. |
## What the bearer path does NOT do
- **Human-user / browser flows.** The bearer path is M2M-only in this
iteration. Browser SPAs that want to attach a bearer to fetch calls work
if your backend treats them as machine clients, but the spec defaults are
tuned for service-to-service traffic.
- **Opaque access tokens.** Tokens must be JWTs. Introspection is a
revocation overlay on top of JWT verification, not a substitute for it.
- **`email_verified` enforcement.** The bearer path rejects `email` as the
identifier claim at startup precisely because `email_verified` is not
enforced in this iteration. Adding human-user bearer support is a
follow-up that must include this check.
- **mTLS / API keys.** Out of scope. The `principal` abstraction enables
adding these later as additional auth methods that produce a principal
for the shared `forwardAuthorized` pipeline.
- **SSE / WebSocket bypass with bearer.** Bypass paths keep their existing
cookie-only behaviour; bearer headers are ignored on those endpoints.
Documented limitation; widen by removing the bypass if you need bearer on
streaming endpoints.
## Operational guidance
- **Always set `strictAudienceValidation: true` when bearer is enabled.**
Startup logs a recommendation if you don't.
- **Set a tight `maxTokenAgeSeconds`** for environments where tokens are
expected to be minted frequently — the default 24h is conservative.
- **Enable `requireTokenIntrospection`** if your IdP supports it and
revocation latency matters. Bearer-path introspection caches results for
a short window per token.
- **Monitor 429s.** Sustained 429 traffic indicates either a buggy client
loop or an active credential-stuffing attempt. The throttle is your
primary signal for both.
- **`stripAuthorizationHeader=false` extends the token's blast radius** to
every downstream service that sees the request. Treat those services'
logs as token stores.
- **Bearer reuse is normal.** Don't enable per-token rate limiting; that's
what `bearerFailureThreshold` is for (per-IP, not per-token).
- **Cookie-wins is the safer default.** Only flip `bearerOverridesCookie`
if you control all clients and have audited that none of them present a
cookie alongside a bearer they don't intend to authenticate with.
## Failure response matrix
| Trigger | Status | Body | `WWW-Authenticate` |
|---|---|---|---|
| Empty bearer after prefix | 401 | `Unauthorized` | `Bearer error="invalid_request"` |
| Token over `MaxLength` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Not a 3-segment JWT | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Disallowed `alg` (e.g. none, HS*) | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Missing / oversized / bad-charset `kid` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Signature / issuer / audience / `exp` failure | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| `iat` older than `maxTokenAgeSeconds` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Multi-audience token without matching `azp` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Detected as ID token | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| JTI blacklisted (revoked) | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Introspection reports `active=false` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Introspection endpoint failure | 503 | `Service Unavailable` | (none) |
| Identifier claim missing / empty | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Identifier fails sanitisation | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Per-IP failure threshold tripped | 429 | `Too Many Requests` | (none); `Retry-After: <bearerFailurePenaltySeconds>` |
| Roles / groups not allowed | 403 | `Access denied` | (none) |
## Known follow-ups (deferred)
These are documented as future work, not blockers:
- **Human-user bearer with `email_verified` enforcement.** Requires
decoupling the email-claim guard from the startup rejection and adding a
per-request `email_verified=true` check.
- **Introspection respects `client_assertion`.** The existing introspection
helper uses `client_secret_basic` only; operators on `private_key_jwt`
will see introspection silently use basic auth.
- **Per-route bearer configuration.** Single middleware-wide setting in this
iteration.
## References
- [PR design spec](superpowers/specs/2026-05-18-bearer-token-auth-design.md) — full design rationale, alternatives considered, and per-section sign-off history.
- [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) — Bearer Token Usage.
- [RFC 7662](https://www.rfc-editor.org/rfc/rfc7662) — OAuth 2.0 Token Introspection.
- [RFC 9068](https://www.rfc-editor.org/rfc/rfc9068) — JWT Profile for OAuth 2.0 Access Tokens.
+1
View File
@@ -0,0 +1 @@
traefikoidc.raczylo.com
+683
View File
@@ -0,0 +1,683 @@
# Configuration Reference
Complete reference for all Traefik OIDC middleware configuration options.
## Table of Contents
- [Required Parameters](#required-parameters)
- [Client Authentication](#client-authentication)
- [Optional Parameters](#optional-parameters)
- [Security Options](#security-options)
- [Session Management](#session-management)
- [Access Control](#access-control)
- [Headers Configuration](#headers-configuration)
- [Security Headers](#security-headers)
- [Scope Configuration](#scope-configuration)
- [Advanced Options](#advanced-options)
- [Standalone binary (oidcgate)](#standalone-binary-oidcgate)
---
## Required Parameters
| Parameter | Type | Description | Example |
|-----------|------|-------------|---------|
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | string | OAuth 2.0 client secret. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`. Optional when `clientAuthMethod: private_key_jwt`. | `your-client-secret` |
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
### Basic Configuration Example
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id.apps.googleusercontent.com
clientSecret: your-client-secret
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
```
---
## Client Authentication
The middleware supports three client authentication methods at the token and
revocation endpoints. The default is `client_secret_post` (current behavior);
`private_key_jwt` is opt-in and backwards compatible.
| Method | Default | Description |
|--------|---------|-------------|
| `client_secret_post` | yes | `client_id` + `client_secret` in the request body. |
| `client_secret_basic` | no | RFC 6749 §2.3.1 — `client_id` + `client_secret` in the `Authorization: Basic` header (form-urlencoded then base64); not in the body. |
| `private_key_jwt` | no | RFC 7523 §2.2 — plugin signs a short-lived JWT with a private key and sends it as `client_assertion`. |
Select via `clientAuthMethod`:
```yaml
clientAuthMethod: private_key_jwt
```
### client_secret_post
Default. The plugin sends `client_id` and `client_secret` as form parameters
in the token / revocation request body. No additional configuration required.
### private_key_jwt
Asymmetric client authentication per
[RFC 7523 §2.2](https://www.rfc-editor.org/rfc/rfc7523). Use this when your
IdP enforces short secret TTLs, when policy mandates secretless clients, or
when you want to avoid distributing a shared secret to the proxy.
For each token / revocation request the plugin builds a JWS with:
- `iss` = `sub` = `clientID`
- `aud` = token endpoint URL
- `iat` = now, `exp` = now + 60s
- `jti` = random hex per request
- `kid` header = `clientAssertionKeyID`
**Required fields:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `clientAuthMethod` | string | `client_secret_post` | Set to `private_key_jwt`. |
| `clientAssertionPrivateKey` | string | none | Inline PEM private key. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8, PKCS#1, and SEC1 formats accepted. |
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk. Mutually exclusive with `clientAssertionPrivateKey`. |
| `clientAssertionKeyID` | string | none | `kid` header inserted in the JWS. Must match the public key registered with the IdP. |
| `clientAssertionAlg` | string | `RS256` | One of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. |
When `clientAuthMethod: private_key_jwt`, `clientSecret` is optional.
**Example — inline PEM:**
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://idp.example.com
clientID: my-client-id
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
clientAuthMethod: private_key_jwt
clientAssertionKeyID: key-2026-01
clientAssertionAlg: RS256
clientAssertionPrivateKey: |
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKj
MZj4ev7QnMa1mYV3Kx1jRkH5YwXQ7N2J2j8K5pP6h0oZmXq1yQv4r8wZb3sH9D2k
... (truncated) ...
-----END PRIVATE KEY-----
```
**Example — key on disk:**
```yaml
clientAuthMethod: private_key_jwt
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
clientAssertionKeyID: key-2026-01
clientAssertionAlg: RS256
```
**Generating an RS256 key with OpenSSL:**
```bash
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 \
-out client-key.pem
openssl rsa -in client-key.pem -pubout -out client-pub.pem
```
Register `client-pub.pem` (or its JWK form) with your IdP under the same
`kid` you set in `clientAssertionKeyID`.
**Notes:**
- The private key is parsed once at plugin startup. Key rotation requires a
Traefik reload.
- Assertion lifetime is fixed at 60 seconds.
- A fresh random `jti` is generated per request.
- The `aud` claim is the token endpoint URL (from discovery).
- Tracking issue:
[#135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
### client_secret_basic
Per [RFC 6749 §2.3.1][rfc6749-2-3-1], the plugin sends the client credentials
in an `Authorization: Basic` header instead of the body. Both halves
(`client_id`, `client_secret`) are form-urlencoded individually, joined with
a colon, then base64-encoded. Use this when your IdP requires Basic auth at
the token endpoint and rejects credentials in the body.
```yaml
clientAuthMethod: client_secret_basic
clientID: your-client-id
clientSecret: your-client-secret
```
[rfc6749-2-3-1]: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
---
## Optional Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
| `clientAuthMethod` | string | `client_secret_post` | Client authentication method at token/revocation endpoints. One of `client_secret_post`, `client_secret_basic`, `private_key_jwt`. See [Client Authentication](#client-authentication). |
| `clientAssertionPrivateKey` | string | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8 / PKCS#1 / SEC1. |
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk for `private_key_jwt`. Mutually exclusive with `clientAssertionPrivateKey`. |
| `clientAssertionKeyID` | string | none | `kid` header for `private_key_jwt` assertions. Required when `clientAuthMethod: private_key_jwt`. |
| `clientAssertionAlg` | string | `RS256` | Signing algorithm for `private_key_jwt`. One of `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
### TLS Termination at Load Balancer
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
the correct default behind any TLS-terminating load balancer (AWS ALB, Google
Cloud LB, Azure App Gateway) — `X-Forwarded-Proto` cannot be trusted (ALB may
overwrite it).
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
dev). Otherwise leave it at default.
### Streaming Endpoints (SSE and WebSocket)
The middleware automatically bypasses the OIDC redirect for two request kinds
that browsers cannot follow a 302 on:
| Bypass | Triggered by |
|--------|--------------|
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
These requests do **not** require any explicit configuration — they are
handled implicitly. However, the bypass is **not** unauthenticated:
- A valid, encrypted session cookie is required. Requests without one are
rejected (the connection cannot proceed to the backend).
- The session cookie is sealed with `sessionEncryptionKey`, so the
`authenticated` flag cannot be forged.
- Validation is cookie-only — no JWK fetch / signature verification — so
streaming endpoints keep working when the OIDC provider is briefly
unavailable.
- The user identifier from the session is forwarded as `X-Forwarded-User`
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
For browser clients, the user must complete the normal OIDC flow on a
regular HTTP page first; the resulting session cookie is then reused on the
SSE / WebSocket connection.
---
## Security Options
### Audience Validation
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `audience` | string | `clientID` | Expected audience for access token validation |
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
#### Production Security Configuration
```yaml
audience: "https://my-api.example.com"
strictAudienceValidation: true
```
#### Opaque Token Support
```yaml
allowOpaqueTokens: true
requireTokenIntrospection: true
strictAudienceValidation: true
```
### Other Security Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
### Bearer-token (M2M) authentication
Opt-in path that accepts `Authorization: Bearer <jwt>` instead of the cookie
session flow. M2M-only, default off, audience-mandatory. See
[docs/BEARER_AUTH.md](BEARER_AUTH.md) for the threat model and operational
guidance.
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enableBearerAuth` | bool | `false` | Master switch. Startup fails if true with empty `audience` or with `bearerIdentifierClaim=email`. |
| `bearerIdentifierClaim` | string | `"sub"` | JWT claim used as the principal identifier. `"email"` is rejected at startup. |
| `stripAuthorizationHeader` | bool | `true` | Strip `Authorization` from forwarded requests after successful bearer auth. |
| `bearerEmitWWWAuthenticate` | bool | `true` | Emit RFC 6750 `WWW-Authenticate: Bearer error="..."` hints on 401. |
| `bearerOverridesCookie` | bool | `false` | Cookie wins when both bearer and cookie are present (default). Set true for bearer-wins. |
| `maxTokenAgeSeconds` | int64 | `86400` | Upper bound on `iat` claim age (24h). 0 disables the check. |
| `maxIdentifierLength` | int | `256` | Length cap on the sanitised principal identifier. |
| `bearerFailureThreshold` | int | `20` | Consecutive 401s from one source IP that trip the throttle. |
| `bearerFailureWindowSeconds` | int | `60` | Rolling window for counting 401s. |
| `bearerFailurePenaltySeconds` | int | `60` | 429 + `Retry-After` duration after the threshold trips. |
---
## Session Management
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
### Multi-Subdomain Setup
```yaml
cookieDomain: .example.com # Share cookies across subdomains
```
### Multiple Middleware Instances
When running multiple middleware instances with different authorization requirements, use unique prefixes:
```yaml
# User authentication middleware
---
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-userauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_userauth_"
sessionEncryptionKey: user-encryption-key-min-32-bytes
# ... other config
---
# Admin authentication middleware
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-adminauth
spec:
plugin:
traefikoidc:
cookiePrefix: "_oidc_adminauth_"
sessionEncryptionKey: admin-encryption-key-min-32-bytes
allowedUsers:
- admin@example.com
# ... other config
```
### Extended Session Duration
```yaml
sessionMaxAge: 604800 # 7 days
# Common values:
# 3600 - 1 hour (high security)
# 86400 - 1 day (default)
# 259200 - 3 days
# 604800 - 7 days
# 2592000 - 30 days
```
---
## Access Control
### User Restrictions
| Parameter | Type | Description |
|-----------|------|-------------|
| `allowedUserDomains` | []string | Restrict to specific email domains |
| `allowedUsers` | []string | Specific email addresses allowed |
| `allowedRolesAndGroups` | []string | Required roles or groups |
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
### Domain Restriction
```yaml
allowedUserDomains:
- company.com
- subsidiary.com
```
### Specific User Access
```yaml
allowedUsers:
- user@example.com
- contractor@external.org
```
### Role-Based Access Control
```yaml
allowedRolesAndGroups:
- admin
- developer
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
```
### Access Control Logic
- If only `allowedUsers` is set: Only specified emails can access
- If only `allowedUserDomains` is set: Only specified domains can access
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
- If neither is set: Any authenticated user can access
### Users Without Email (Azure AD)
For Azure AD service accounts or users without email:
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
```
---
## Headers Configuration
### Default Headers
The middleware sets these headers for downstream services:
| Header | Description |
|--------|-------------|
| `X-Forwarded-User` | User's email address |
| `X-User-Groups` | Comma-separated user groups |
| `X-User-Roles` | Comma-separated user roles |
| `X-Auth-Request-Redirect` | Original request URI |
| `X-Auth-Request-User` | User's email address |
| `X-Auth-Request-Token` | User's ID token |
### Minimal Headers Mode
For "431 Request Header Fields Too Large" errors:
```yaml
minimalHeaders: true # Only forwards X-Forwarded-User
```
### Custom Templated Headers
```yaml
headers:
- name: "X-User-Email"
value: "{{{{.Claims.email}}}}"
- name: "X-User-ID"
value: "{{{{.Claims.sub}}}}"
- name: "Authorization"
value: "Bearer {{{{.AccessToken}}}}"
- name: "X-User-Roles"
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
```
**Template Variables:**
- `{{.Claims.field}}` - ID token claims
- `{{.AccessToken}}` - Raw access token
- `{{.IdToken}}` - Raw ID token
- `{{.RefreshToken}}` - Raw refresh token
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
---
## Security Headers
### Security Profiles
| Profile | Use Case | Security Level |
|---------|----------|----------------|
| `default` | Standard web apps | High |
| `strict` | Maximum security | Very High |
| `development` | Local development | Medium |
| `api` | API endpoints | High |
| `custom` | Custom requirements | Configurable |
### Basic Configuration
```yaml
securityHeaders:
enabled: true
profile: "default"
```
### API with CORS
```yaml
securityHeaders:
enabled: true
profile: "api"
corsEnabled: true
corsAllowedOrigins:
- "https://your-frontend.com"
- "https://*.example.com"
corsAllowCredentials: true
```
### Custom Security Configuration
```yaml
securityHeaders:
enabled: true
profile: "custom"
# Content Security Policy
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
# HSTS
strictTransportSecurity: true
strictTransportSecurityMaxAge: 31536000
strictTransportSecuritySubdomains: true
strictTransportSecurityPreload: true
# Frame and Content Protection
frameOptions: "DENY"
contentTypeOptions: "nosniff"
xssProtection: "1; mode=block"
referrerPolicy: "strict-origin-when-cross-origin"
# CORS
corsEnabled: true
corsAllowedOrigins: ["https://app.example.com"]
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
corsAllowedHeaders: ["Authorization", "Content-Type"]
corsAllowCredentials: true
corsMaxAge: 86400
# Custom Headers
customHeaders:
X-Custom-Header: "value"
# Server Identification
disableServerHeader: true
disablePoweredByHeader: true
```
### CORS Origin Patterns
```yaml
corsAllowedOrigins:
- "https://example.com" # Exact match
- "https://*.example.com" # Subdomain wildcard
- "http://localhost:*" # Port wildcard (development)
```
---
## Scope Configuration
### Default Behavior (Append Mode)
```yaml
scopes:
- roles
- custom_scope
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
```
### Override Mode
```yaml
overrideScopes: true
scopes:
- openid
- profile
- custom_scope
# Result: ["openid", "profile", "custom_scope"]
```
---
## Advanced Options
### Dynamic Client Registration (RFC 7591)
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
**Basic Configuration (Single Instance):**
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional, if provider requires it
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
application_type: "web"
grant_types:
- "authorization_code"
- "refresh_token"
```
**Multi-Replica Deployment (Kubernetes):**
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
```yaml
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: "redis" # Share credentials via Redis
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
**Storage Backend Options:**
| Backend | Description | Use Case |
|---------|-------------|----------|
| `file` | Store credentials in local file | Single instance deployments |
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
### Multi-Replica Deployment
Without Redis, disable replay detection:
```yaml
disableReplayDetection: true
```
With Redis (recommended):
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
See [REDIS.md](REDIS.md) for complete Redis configuration.
---
## Kubernetes Secrets
Reference secrets instead of hardcoding sensitive values:
```yaml
providerURL: urn:k8s:secret:oidc-secret:ISSUER
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
clientSecret: urn:k8s:secret:oidc-secret:SECRET
```
Create the secret:
```bash
kubectl create secret generic oidc-secret \
--from-literal=ISSUER=https://accounts.google.com \
--from-literal=CLIENT_ID=your-client-id \
--from-literal=SECRET=your-client-secret \
-n traefik
```
---
## Environment Variable Naming
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
```yaml
# Bad - may cause issues
sessionEncryptionKey: ${OIDC_SECRET_API}
# Good
sessionEncryptionKey: ${OIDC_SECRET_SVC}
```
---
## Standalone binary (oidcgate)
If you don't run Traefik, the same configuration shape documented above
works for the [`oidcgate`](OIDCGATE.md) standalone forward-auth daemon
under `cmd/oidcgate`. Three extra top-level keys (`listen`, `authPath`,
`startPath`) configure the daemon itself; everything else maps 1:1 onto
the `traefikoidc.Config` fields documented in this reference.
See [`docs/OIDCGATE.md`](OIDCGATE.md) for the full daemon guide including
nginx, Caddy, Traefik ForwardAuth, HAProxy and Envoy wiring snippets,
the `OIDCGATE_*` environment-variable inventory, the security posture
(X-Forwarded-Uri sanitisation, excludedURLs guardrail), and how to layer
M2M [bearer-token auth](BEARER_AUTH.md) on the same daemon.
+95
View File
@@ -0,0 +1,95 @@
# Dynamic Client Registration (RFC 7591)
The middleware can register itself with an OIDC provider at startup instead of
using a pre-provisioned `clientID` / `clientSecret`. Useful for multi-tenant
deployments, self-service integrations, and ephemeral environments.
## How it works
1. Middleware reads `registration_endpoint` from `.well-known/openid-configuration`.
2. If `clientID` is empty, it `POST`s `clientMetadata` to the registration endpoint.
3. Returned `client_id` / `client_secret` are cached, optionally persisted.
4. Subsequent requests use the registered credentials.
For multi-replica deployments, set `storageBackend: redis` so all replicas
share one client and avoid registration races.
## Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-dcr
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://your-oidc-provider.com
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: redis # file | redis | auto
initialAccessToken: "" # optional, for protected endpoints
registrationEndpoint: "" # optional, override discovery
credentialsFile: /tmp/oidc-client-credentials.json
redisKeyPrefix: "dcr:creds:"
clientMetadata:
redirect_uris:
- https://app.example.com/oauth2/callback
client_name: My Application
application_type: web
grant_types: [authorization_code, refresh_token]
response_types: [code]
token_endpoint_auth_method: client_secret_basic
contacts: [admin@example.com]
```
## Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `enabled` | `false` | Enable DCR. |
| `persistCredentials` | `false` | Save returned credentials for reuse across restarts. |
| `storageBackend` | `auto` | `file`, `redis`, or `auto` (Redis if available, else file). |
| `credentialsFile` | `/tmp/oidc-client-credentials.json` | Path for file-backed storage. Mode `0600`. |
| `redisKeyPrefix` | (none — set explicitly) | Key prefix for Redis-backed storage. The code does not inject a default; if unset, keys have no prefix. `dcr:creds:` is a sensible convention. |
| `registrationEndpoint` | discovered | Override the discovered endpoint. |
| `initialAccessToken` | none | Bearer token for protected registration endpoints. |
| `clientMetadata.redirect_uris` | required | Callback URIs for the OAuth flow. |
| `clientMetadata.client_name` | none | Human-readable client name. |
| `clientMetadata.application_type` | `web` | `web` or `native`. |
| `clientMetadata.grant_types` | `[authorization_code, refresh_token]` | OAuth grant types. |
| `clientMetadata.response_types` | `[code]` | OAuth response types. |
| `clientMetadata.token_endpoint_auth_method` | `client_secret_basic` | `client_secret_basic`, `client_secret_post`, or `none`. |
| `clientMetadata.scope` | none | Space-separated scopes. |
| `clientMetadata.contacts` | none | Admin email addresses. |
| `clientMetadata.logo_uri` | none | Logo URL for consent screens. |
| `clientMetadata.client_uri` | none | Client homepage URL. |
| `clientMetadata.policy_uri` | none | Privacy policy URL. |
| `clientMetadata.tos_uri` | none | Terms of service URL. |
## Provider support
The middleware does not gate DCR by provider — if the provider exposes a
`registration_endpoint` in its discovery document (or you set
`registrationEndpoint` explicitly), DCR will attempt registration. The table
below is informational guidance based on each provider's published support.
| Provider | DCR | Notes |
|----------|-----|-------|
| Keycloak | Yes | Enable in realm settings. |
| Auth0 | Yes | Requires Management API token. |
| Okta | Yes | Enable Dynamic Client Registration in admin console. |
| Azure AD | Limited | Use App Registration API instead. |
| Google | No | Manual registration required. |
| AWS Cognito | No | Manual registration required. |
## Security notes
- Registration endpoints must be HTTPS (loopback excepted for local dev).
- Use `initialAccessToken` in production to gate registration.
- File-backed credentials use `0600`; protect the mount path.
- The plugin marks credentials invalid when within ~5 min of `client_secret_expires_at` but does **not** automatically re-register. If your provider sets a non-zero expiry, schedule manual rotation (delete the credentials file or Redis entry, restart) before that time.
+376
View File
@@ -0,0 +1,376 @@
# Development Guide
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
## Table of Contents
- [Prerequisites](#prerequisites)
- [Local Development Setup](#local-development-setup)
- [Running Tests](#running-tests)
- [Test Categories](#test-categories)
- [CI/CD Pipeline](#cicd-pipeline)
- [Code Quality](#code-quality)
- [Contributing](#contributing)
---
## Prerequisites
- **Go 1.24+** (matches `go.mod`; CI runs Go 1.24.11)
- **OIDC Provider** credentials (Google, Azure, etc.) for any end-to-end test against a real provider
### Required Development Tools
```bash
# golangci-lint (comprehensive linting)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# staticcheck (static analysis)
go install honnef.co/go/tools/cmd/staticcheck@latest
# gosec (security scanning)
go install github.com/securego/gosec/v2/cmd/gosec@latest
# govulncheck (vulnerability scanning)
go install golang.org/x/vuln/cmd/govulncheck@latest
```
---
## Local Development Setup
### Build and unit tests
```bash
go mod tidy
go build ./...
go test ./... -short # fast loop, < 30 s
go test -race -timeout=15m ./...
```
### Sample plugin configurations
Working middleware/Traefik configs live in [`examples/`](../examples/):
- `complete-traefik-config.yaml` — full middleware example
- `redis-config.yaml` — Redis cache configuration
To run the plugin against a real Traefik instance, drop the project on disk
and load it via `experimental.localPlugins` in your Traefik static config —
see the [README install section](../README.md#install).
### Integration tests
Integration tests live in `integration/`. Run them explicitly:
```bash
go test ./integration/... -run Integration -v
```
---
## Running Tests
### Quick Start
```bash
# Fast development testing (< 30 seconds)
go test ./... -short
# Standard tests with race detector
go test -race -timeout=15m ./...
# With coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
```
### Test Modes
| Mode | Command | Duration | Use Case |
|------|---------|----------|----------|
| Quick | `go test ./... -short` | < 30s | During development |
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
### Environment Variables
```bash
# Enable specific test types
export RUN_EXTENDED_TESTS=1
export RUN_LONG_TESTS=1
export RUN_STRESS_TESTS=1
# Disable specific features
export DISABLE_LEAK_DETECTION=1
# Customize test parameters
export TEST_MAX_CONCURRENCY=10
export TEST_MAX_ITERATIONS=50
export TEST_MEMORY_THRESHOLD_MB=25.5
```
---
## Test Categories
### Quick Tests (Default)
- Basic functionality verification
- Limited iterations (1-3)
- Small data sets
- Essential memory leak checks
**Configuration:**
- Max Iterations: 3
- Max Concurrency: 5
- Memory Threshold: 2.0 MB
- Timeout: 10 seconds
### Extended Tests
- Comprehensive testing before commits
- More iterations (5-10)
- Enhanced memory leak detection
**Configuration:**
- Max Iterations: 10
- Max Concurrency: 20
- Memory Threshold: 10.0 MB
- Timeout: 30 seconds
### Long Tests
- Performance validation
- High iteration counts (50-100)
- Large data sets
**Configuration:**
- Max Iterations: 100
- Max Concurrency: 50
- Memory Threshold: 50.0 MB
- Timeout: 60 seconds
### Stress Tests
- Maximum load testing
- Edge case validation
- Extreme parameters
**Configuration:**
- Max Iterations: 500
- Max Concurrency: 100
- Memory Threshold: 100.0 MB
- Timeout: 120 seconds
### Running Specific Test Suites
```bash
# Memory leak tests
go test -v -run='.*Leak.*' ./...
# Integration tests
go test -v -run='.*Integration.*' ./...
# Regression tests
go test -v -run='.*Regression.*' ./...
# Provider-specific tests
go test -v -run='.*Azure.*' ./...
go test -v -run='.*Google.*' ./...
```
### Benchmarks
```bash
# Quick benchmarks
go test -bench=. -short
# Extended benchmarks
RUN_EXTENDED_TESTS=1 go test -bench=.
# Memory profiling
go test -bench=. -memprofile=mem.prof
go tool pprof mem.prof
```
---
## CI/CD Pipeline
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
### Triggered On
- Pull requests to `main` branch
- Pushes to `main` branch
### Parallel Jobs
#### Code Quality (3 checks)
- **Format & Basic Checks** - gofmt, go vet, go mod
- **golangci-lint** - 30+ linters
- **Staticcheck** - Advanced static analysis
#### Security (3 checks)
- **Gosec** - Security vulnerability scanning
- **Govulncheck** - Go vulnerability database
- **CodeQL** - GitHub's semantic code analysis
#### Testing (9 suites)
- Race Detector
- Coverage (70% threshold, enforced in `pr.yaml`)
- Memory Leaks
- Integration Tests
- Regression Tests
- Security Edge Cases
- Session Tests
- Token Tests
- CSRF Tests
#### Provider Testing (9 providers)
Tests run in parallel for:
- Google
- Azure AD
- Auth0
- Okta
- Keycloak
- AWS Cognito
- GitLab
- GitHub
- Generic OIDC
#### Performance & Build (3 checks)
- Benchmarks
- Multi-platform Build (linux/darwin x amd64/arm64)
- Go Version Compatibility (currently Go 1.24.11 in CI)
### Quality Gates
All PRs must pass:
- All parallel checks
- 70% test coverage minimum
- Zero security vulnerabilities
- No race conditions
- No memory leaks
- All providers tested
- Builds on all platforms
---
## Code Quality
### Pre-Commit Checklist
```bash
# Run before every commit
gofmt -s -w . && \
go mod tidy && \
golangci-lint run && \
go test -race -short ./... && \
echo "Ready to commit!"
```
### Local Validation
```bash
# Format code
gofmt -s -w .
# Run linter
golangci-lint run
# Static analysis
staticcheck ./...
# Security scan
gosec ./...
# Vulnerability check
govulncheck ./...
# Tests with race detector
go test -race -timeout=15m -count=1 ./...
# Coverage report
go test -coverprofile=coverage.out ./...
go tool cover -func=coverage.out
# View coverage in browser
go tool cover -html=coverage.out
```
### Troubleshooting
**Coverage Below Threshold:**
```bash
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out # See uncovered lines
```
**Race Condition Found:**
```bash
go test -race -v -run=TestName ./...
```
**Linter Errors:**
```bash
golangci-lint run -v
golangci-lint run --fix # Auto-fix some issues
```
**Provider Test Fails:**
```bash
go test -v -run='.*Azure.*' ./...
```
---
## Contributing
### Development Guidelines
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
4. **Documentation**: Update README and configuration files for new options
### Pull Request Template
PRs should include:
- Description of changes
- Type of change (bug fix, feature, breaking change, etc.)
- Related issues
- Provider impact (which providers are affected)
- Testing performed
- Security considerations
- Performance impact
- Breaking changes (if any)
### Checklist
Before submitting:
- [ ] Code follows project style
- [ ] Self-review completed
- [ ] Tests added for new functionality
- [ ] All tests pass locally
- [ ] Documentation updated
- [ ] No new warnings generated
### Code Owners
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
### Dependabot
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
---
## Additional Resources
- [golangci-lint Rules](.golangci.yml)
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
- [Workflow Documentation](.github/workflows/README.md)
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
+362
View File
@@ -0,0 +1,362 @@
# oidcgate — standalone OIDC forward-auth daemon
`oidcgate` is a single binary that exposes the same OIDC middleware that
powers the Traefik plugin as a forward-auth daemon for nginx, Caddy,
Traefik ForwardAuth, HAProxy, and Envoy `ext_authz_http`.
## Table of contents
- [Build](#build)
- [Run](#run)
- [Configuration](#configuration)
- [YAML file](#yaml-file)
- [Environment-variable overrides](#environment-variable-overrides)
- [Endpoints](#endpoints)
- [Reverse-proxy snippets](#reverse-proxy-snippets)
- [nginx (`auth_request`)](#nginx-auth_request)
- [Caddy (`forward_auth`)](#caddy-forward_auth)
- [Traefik (`ForwardAuth`)](#traefik-forwardauth)
- [HAProxy](#haproxy)
- [Envoy (`ext_authz_http`)](#envoy-ext_authz_http)
- [Security posture](#security-posture)
- [Bearer-token (M2M) auth on the same daemon](#bearer-token-m2m-auth-on-the-same-daemon)
- [Operational guidance](#operational-guidance)
- [Debugging](#debugging)
## Build
```bash
go build -o oidcgate ./cmd/oidcgate
```
## Run
```bash
./oidcgate --config /etc/oidcgate/config.yaml
```
The daemon parses `--config`, loads YAML, applies any `OIDCGATE_*` env-var
overrides, validates the result, and binds to `listen`. On SIGINT/SIGTERM it
calls `http.Server.Shutdown` with a 15s deadline, draining in-flight requests.
## Configuration
### YAML file
The OIDC subtree of the config maps 1:1 onto the [`traefikoidc.Config`](CONFIGURATION.md)
struct — every field documented under "Configuration Reference" works here
verbatim. Three extra top-level keys configure the daemon itself:
| Key | Default | Purpose |
|---|---|---|
| `listen` | _required_ | TCP address (e.g. `:8080`, `127.0.0.1:8080`). |
| `authPath` | `/oauth2/auth` | Silent-probe endpoint (used by nginx `auth_request`). |
| `startPath` | `/oauth2/start` | Visible sign-in endpoint. |
Minimal example (see [`examples/oidcgate.yaml`](../examples/oidcgate.yaml)):
```yaml
listen: ":8080"
providerURL: "https://accounts.google.com"
clientID: "your-client-id"
clientSecret: "your-client-secret"
sessionEncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
```
Nested structs (`redis:`, `securityHeaders:`, `dynamicClientRegistration:`)
round-trip cleanly through YAML — same shape as in `.traefik.yml`.
### Environment-variable overrides
Any of the following scalar fields can be overridden at runtime by an
`OIDCGATE_<UPPER_SNAKE_CASE>` environment variable. The env var wins over
the YAML value when set and non-empty. Intended for secret injection
(K8s `valueFrom.secretKeyRef`, systemd `EnvironmentFile=`, etc.).
| YAML key | Env var |
|---|---|
| `listen` | `OIDCGATE_LISTEN` |
| `authPath` | `OIDCGATE_AUTH_PATH` |
| `startPath` | `OIDCGATE_START_PATH` |
| `providerURL` | `OIDCGATE_PROVIDER_URL` |
| `clientID` | `OIDCGATE_CLIENT_ID` |
| `clientSecret` | `OIDCGATE_CLIENT_SECRET` |
| `audience` | `OIDCGATE_AUDIENCE` |
| `callbackURL` | `OIDCGATE_CALLBACK_URL` |
| `logoutURL` | `OIDCGATE_LOGOUT_URL` |
| `postLogoutRedirectURI` | `OIDCGATE_POST_LOGOUT_REDIRECT_URI` |
| `sessionEncryptionKey` | `OIDCGATE_SESSION_ENCRYPTION_KEY` |
| `cookiePrefix` | `OIDCGATE_COOKIE_PREFIX` |
| `cookieDomain` | `OIDCGATE_COOKIE_DOMAIN` |
| `logLevel` | `OIDCGATE_LOG_LEVEL` |
| `revocationURL` | `OIDCGATE_REVOCATION_URL` |
| `oidcEndSessionURL` | `OIDCGATE_OIDC_END_SESSION_URL` |
| `userIdentifierClaim` | `OIDCGATE_USER_IDENTIFIER_CLAIM` |
| `groupClaimName` | `OIDCGATE_GROUP_CLAIM_NAME` |
| `roleClaimName` | `OIDCGATE_ROLE_CLAIM_NAME` |
| `clientAuthMethod` | `OIDCGATE_CLIENT_AUTH_METHOD` |
| `clientAssertionPrivateKey` | `OIDCGATE_CLIENT_ASSERTION_PRIVATE_KEY` |
| `clientAssertionKeyPath` | `OIDCGATE_CLIENT_ASSERTION_KEY_PATH` |
| `clientAssertionKeyID` | `OIDCGATE_CLIENT_ASSERTION_KEY_ID` |
| `clientAssertionAlg` | `OIDCGATE_CLIENT_ASSERTION_ALG` |
| `caCertPath` | `OIDCGATE_CA_CERT_PATH` |
| `caCertPEM` | `OIDCGATE_CA_CERT_PEM` |
Nested-struct fields (Redis, security headers, DCR) are YAML-only — set
them in the config file, not via env.
## Endpoints
| Path | Method | Purpose |
|---|---|---|
| `/oauth2/auth` | GET | Silent probe — `200` if authenticated, `401` if not. Never returns `302`; the middleware's redirect-to-IdP is rewritten in-flight to `401` with the original `Location` carried as `X-Auth-Redirect`. |
| `/oauth2/start` | GET | Visible sign-in — `302` to the IdP authorize URL. Accepts `?rd=<safe-path>` (or honours `X-Forwarded-Uri`) for the post-login redirect target. |
| `/oauth2/callback` | GET | IdP `code`+`state` exchange. Path is configurable via `callbackURL`. |
| `/oauth2/logout` | GET/POST | Terminates the session. Path is configurable via `logoutURL`. Honours `oidcEndSessionURL` for RP-initiated logout. |
| `/healthz` | GET | Liveness — `200` while the process is alive. |
| `/readyz` | GET | Readiness — `200` once the OIDC discovery document has been fetched, otherwise `503`. |
## Reverse-proxy snippets
### nginx (`auth_request`)
```nginx
location = /oauth2/auth {
internal;
proxy_pass http://oidcgate:8080;
proxy_pass_request_body off;
proxy_set_header Content-Length "";
proxy_set_header X-Forwarded-Uri $request_uri;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Proto $scheme;
}
location @oidc_signin {
return 302 /oauth2/start?rd=$scheme://$host$request_uri;
}
location /oauth2/ {
proxy_pass http://oidcgate:8080;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Proto $scheme;
}
location / {
auth_request /oauth2/auth;
error_page 401 = @oidc_signin;
auth_request_set $user $upstream_http_x_forwarded_user;
auth_request_set $email $upstream_http_x_forwarded_email;
proxy_set_header X-Forwarded-User $user;
proxy_set_header X-Forwarded-Email $email;
proxy_pass http://backend;
}
```
### Caddy (`forward_auth`)
```caddyfile
example.com {
forward_auth oidcgate:8080 {
uri /oauth2/auth
copy_headers X-Forwarded-User X-Forwarded-Email
@denied status 401
handle_response @denied {
redir /oauth2/start?rd={http.request.uri} 302
}
}
handle /oauth2/* {
reverse_proxy oidcgate:8080
}
reverse_proxy backend:3000
}
```
### Traefik (`ForwardAuth`)
```yaml
http:
middlewares:
oidcgate:
forwardAuth:
address: "http://oidcgate:8080/oauth2/auth"
authResponseHeaders:
- X-Forwarded-User
- X-Forwarded-Email
```
Traefik can follow the `X-Auth-Redirect` value via a chained `redirectScheme`
middleware, or you can configure the upstream router to redirect `401`
`/oauth2/start` directly.
### HAProxy
```haproxy
frontend fe_https
bind *:443 ssl crt /etc/haproxy/certs/site.pem
http-request set-var(req.orig_uri) path
http-request send-spoe-group oidc auth-check # or use lua/SPOE; simplest is the lua snippet below
# The simpler pattern: dispatch /oauth2/* to oidcgate, everything else
# goes through a Lua filter that issues a sub-request to /oauth2/auth.
acl is_oidc_endpoint path_beg /oauth2/
use_backend be_oidcgate if is_oidc_endpoint
default_backend be_app
backend be_oidcgate
server oidcgate1 oidcgate:8080
backend be_app
server app1 backend:3000
```
HAProxy does not have a first-class `auth_request` equivalent in pure
config — the canonical patterns are SPOE (Stream Processing Offload Engine),
a Lua filter that issues `/oauth2/auth` and reads the response, or a
sidecar that does the dance. Reach for SPOE for high-throughput
production; Lua is simpler for low-volume.
### Envoy (`ext_authz_http`)
```yaml
http_filters:
- name: envoy.filters.http.ext_authz
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz
transport_api_version: V3
http_service:
server_uri:
uri: http://oidcgate:8080
cluster: oidcgate
timeout: 2s
path_prefix: /oauth2/auth
authorization_request:
allowed_headers:
patterns:
- exact: cookie
- exact: authorization
- prefix: x-forwarded-
authorization_response:
allowed_upstream_headers:
patterns:
- exact: x-forwarded-user
- exact: x-forwarded-email
allowed_client_headers:
patterns:
- exact: x-auth-redirect
- exact: set-cookie
```
On `401`, the `X-Auth-Redirect` header is surfaced to the downstream client
via `allowed_client_headers`. A small Envoy `router` filter or
`local_reply_config` rule can convert that into a browser-facing `302`
redirect to `/oauth2/start`.
## Security posture
- **`X-Forwarded-Uri` is sanitised.** The daemon forces
`TrustForwardedURI=true` so the middleware honours `X-Forwarded-Uri` for
the post-login redirect target. To prevent open redirects (CWE-601),
the value is rejected unless it is a safe same-origin path: must start
with `/`, must NOT start with `//` (protocol-relative), and must have
no scheme or host after parsing. Absolute URLs or anything that could
redirect off-origin falls through to `req.URL.RequestURI()`.
- **`excludedURLs` cannot bypass the daemon's own paths.** At config
load, the loader rejects any `excludedURLs` entry that is a prefix of
`authPath`, `startPath`, `callbackURL`, `logoutURL`, or the internal
sentinel path. A misconfiguration like `excludedURLs: ["/"]` (common
"allow all then add auth selectively" mistake) is rejected at startup
with a descriptive error.
- **`callbackURL` and `logoutURL` must be paths.** Absolute URLs are
rejected at config load — both because `http.ServeMux.Handle` panics
on non-`/` patterns and because the middleware's path-match would
silently fail.
- **`listen` is required.** Empty or missing `listen` is rejected at
startup rather than failing later at `net.Listen`.
- **Secrets via env vars.** `clientSecret` and `sessionEncryptionKey`
can be supplied via env vars instead of YAML so they don't end up on
disk if you use a secret manager.
## Bearer-token (M2M) auth on the same daemon
oidcgate uses the full `traefikoidc.Config` shape, so the bearer-token
M2M auth path documented in [`BEARER_AUTH.md`](BEARER_AUTH.md) works
out of the box. Add to your YAML:
```yaml
enableBearerAuth: true
audience: "https://api.example.com"
bearerIdentifierClaim: "sub"
# stripAuthorizationHeader: true # default
# bearerOverridesCookie: false # default — cookie wins on collision
```
With this set, the daemon accepts both:
- Browser users hitting `/oauth2/auth` → cookie session flow.
- API clients calling the protected backend with `Authorization: Bearer <jwt>`
→ bearer validation, principal headers, no session.
The bearer path doesn't go through `/oauth2/auth` separately — it's
applied by the middleware on every request the daemon sees, before the
cookie session check. See [BEARER_AUTH.md](BEARER_AUTH.md) for the full
threat model, identifier sanitisation rules, and failure-response
matrix.
## Operational guidance
- **Run behind a fronting proxy on a private network.** The daemon does
not terminate TLS. Put it on a localhost socket or a private subnet
reachable only from your nginx/Caddy/Traefik/HAProxy/Envoy.
- **`/healthz` and `/readyz` are unauthenticated** — correct for
Kubernetes liveness/readiness probes, but **do not expose them past a
load balancer**. Restrict via an ACL: nginx `allow 10.0.0.0/8; deny
all;`, Caddy `@health remote_ip 10.0.0.0/8`, k8s NetworkPolicy, or
your CNI of choice.
- **Multi-replica deployments** need a shared session store. Enable the
`redis:` block in the config (see [`docs/REDIS.md`](REDIS.md)) so
sessions survive a hop between replicas.
- **No built-in Prometheus metrics yet.** If you need request-level
visibility, take it from your fronting proxy's access logs — both
nginx and Envoy can tag `auth_request` / `ext_authz` outcomes.
- **Logs are minimal by default.** Set `logLevel: debug` while
bringing up a new deployment; raise to `info` (default) or higher
once stable. Debug logs include path-match decisions and metadata
refresh outcomes.
- **Graceful shutdown is 15s.** SIGINT or SIGTERM triggers
`http.Server.Shutdown(ctx)` with a 15-second deadline; in-flight
requests are allowed to complete. If your orchestrator's grace
period is shorter, requests can be cut mid-flight.
## Debugging
- **Requests appear as `/__oidcgate_protected__` in middleware debug
logs.** This is the internal sentinel path used when `/oauth2/auth`
and `/oauth2/start` delegate into the traefikoidc middleware. The
upstream client never sees it; it only shows up in the middleware's
own `Debugf` output when `logLevel: debug` is set.
- **`/oauth2/auth` returns `401` with `X-Auth-Redirect` header on
unauthenticated requests.** This is the deliberate translation of the
middleware's `302` to make nginx `auth_request` work. The browser is
redirected via the fronting proxy's `error_page 401 = @oidc_signin;`
pattern, not by following the daemon's response directly.
- **`/readyz` stays `503` after startup.** The middleware fetches the
OIDC discovery document lazily on first request, so `/readyz` returns
`503` until at least one request has triggered metadata discovery.
Hit `/oauth2/auth` once after startup to warm it up — many K8s
setups achieve the same effect because the liveness probe already
goes through the proxy chain.
- **Cookie/session diagnostics.** With `logLevel: debug` the middleware
logs which session manager was selected (in-memory vs Redis), whether
cookies decrypted successfully, and the JWT validation outcome.
- **Open-redirect rejections are silent.** When the daemon ignores an
unsafe `X-Forwarded-Uri` value, it falls back to `req.URL.RequestURI()`
without logging. This is intentional (no recon signal) — if a user
reports "I keep landing on the wrong page after login", inspect
whether the upstream proxy is forwarding a non-canonical
`X-Forwarded-Uri` value.
+582
View File
@@ -0,0 +1,582 @@
# OIDC Provider Configuration Guide
Configuration reference for each supported OIDC provider.
## Table of Contents
- [Provider Support Matrix](#provider-support-matrix)
- [Google](#google)
- [Microsoft Azure AD](#microsoft-azure-ad)
- [Auth0](#auth0)
- [Okta](#okta)
- [Keycloak](#keycloak)
- [AWS Cognito](#aws-cognito)
- [GitLab](#gitlab)
- [GitHub](#github)
- [Generic OIDC](#generic-oidc)
- [Automatic Scope Filtering](#automatic-scope-filtering)
---
## Provider Support Matrix
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|----------|-------------|----------------|----------------|-----------|
| Google | Full | Yes | `accounts.google.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com`, `sts.windows.net` | Yes |
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
| Okta | Full | Yes | `*.okta.com`, `*.oktapreview.com`, `*.okta-emea.com` | Yes |
| Keycloak | Full | Yes | host containing `keycloak`, or `/realms/` in path (matches both `/auth/realms/` legacy and `/realms/` modern) | Yes |
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
| GitLab | Full | Yes | `gitlab.com` | Yes |
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
| Generic | Full | Yes | Any OIDC endpoint | Yes |
---
## Google
### Provider URL
```yaml
providerURL: "https://accounts.google.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
spec:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-id.apps.googleusercontent.com"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- email
- profile
allowedUserDomains:
- "your-gsuite-domain.com" # Optional: Workspace restriction
forceHttps: true
enablePkce: true
```
### Google-Specific Features
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
### Google Cloud Console Setup
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
2. Create or select a project
3. Navigate to APIs & Services > Credentials
4. Create OAuth 2.0 Client ID (Web application)
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
6. Configure OAuth consent screen (must be "Published" for production)
---
## Microsoft Azure AD
### Provider URL
```yaml
# Single tenant
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
# Multi-tenant
providerURL: "https://login.microsoftonline.com/common/v2.0"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
allowedRolesAndGroups:
- "App.Users"
- "Admin.Group"
forceHttps: true
```
### With Application ID URI (API Access)
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-api
spec:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/common/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
audience: "api://your-azure-client-id" # Application ID URI
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
forceHttps: true
```
### Users Without Email
```yaml
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
allowedUsers:
- "user-object-id-1"
- "user-object-id-2"
```
### Azure AD Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to Azure Active Directory > App registrations
3. Create new registration
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
5. Create client secret in Certificates & secrets
6. Configure Token Configuration for group claims
---
## Auth0
### Provider URL
```yaml
providerURL: "https://your-domain.auth0.com"
```
### Basic Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- offline_access
postLogoutRedirectUri: "https://your-app.com"
forceHttps: true
enablePkce: true
```
### With Custom API Audience
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth0-api
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.auth0.com"
clientID: "your-auth0-client-id"
clientSecret: "your-auth0-client-secret"
audience: "https://api.your-domain.com" # API identifier
strictAudienceValidation: true
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
roleClaimName: "https://your-app.com/roles" # Namespaced claim
groupClaimName: "https://your-app.com/groups"
allowedRolesAndGroups:
- admin
- editor
```
### Auth0 Action for Custom Claims
```javascript
exports.onExecutePostLogin = async (event, api) => {
const namespace = 'https://your-app.com/';
if (event.authorization) {
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
api.idToken.setCustomClaim('email', event.user.email);
}
};
```
### Auth0 Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create Regular Web Application
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
5. Enable OIDC Conformant in Advanced Settings
6. Create API in APIs section for custom audiences
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
---
## Okta
### Provider URL
```yaml
providerURL: "https://your-domain.okta.com"
# Or with custom authorization server:
providerURL: "https://your-domain.okta.com/oauth2/default"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-okta
spec:
plugin:
traefikoidc:
providerURL: "https://your-domain.okta.com"
clientID: "your-okta-client-id"
clientSecret: "your-okta-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- groups
- offline_access
allowedRolesAndGroups:
- admin
- "Everyone"
forceHttps: true
enablePkce: true
```
### Okta Setup
1. Access Okta Admin Console
2. Go to Applications > Create App Integration
3. Select OIDC - OpenID Connect > Web Application
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
6. Enable Authorization Code and Refresh Token grant types
7. Configure Groups claim in authorization server
---
## Keycloak
### Provider URL
```yaml
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-keycloak
spec:
plugin:
traefikoidc:
providerURL: "https://keycloak.company.com/realms/your-realm"
clientID: "your-keycloak-client-id"
clientSecret: "your-keycloak-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- roles
- groups
- offline_access
allowedRolesAndGroups:
- admin
- editor
forceHttps: true
enablePkce: true
```
### Internal Network Deployment
For private IP addresses (Docker networks, Kubernetes):
```yaml
providerURL: "https://192.168.1.100:8443/realms/your-realm"
allowPrivateIPAddresses: true # Required for private IPs
```
### Keycloak Client Setup
1. Access Keycloak Admin Console
2. Select your realm
3. Go to Clients > Create client
4. Set Client Protocol: openid-connect
5. Set Access Type: confidential
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
7. Generate client secret in Credentials tab
8. Configure mappers to add claims to ID Token:
- Email: User Property mapper with "Add to ID token" enabled
- Roles: User Client Role mapper with "Add to ID token" enabled
- Groups: Group Membership mapper with "Add to ID token" enabled
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
---
## AWS Cognito
### Provider URL
```yaml
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-cognito
spec:
plugin:
traefikoidc:
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
clientID: "your-cognito-client-id"
clientSecret: "your-cognito-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
- aws.cognito.signin.user.admin
allowedRolesAndGroups:
- admin
- users
forceHttps: true
```
### AWS Cognito Setup
1. Create Cognito User Pool
2. Create App Client with OIDC scopes
3. Configure App Client settings:
- Callback URLs: `https://your-domain.com/oauth2/callback`
- Sign out URLs: `https://your-domain.com/oauth2/logout`
- OAuth flows: Authorization code grant
4. Configure hosted UI domain (optional)
5. Set up groups for role-based access
---
## GitLab
### Provider URL
```yaml
# GitLab.com
providerURL: "https://gitlab.com"
# Self-hosted
providerURL: "https://gitlab.your-company.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-gitlab
spec:
plugin:
traefikoidc:
providerURL: "https://gitlab.com"
clientID: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
# Note: GitLab doesn't require offline_access scope
# Refresh tokens are issued automatically with openid
allowedRolesAndGroups:
- developers
- maintainers
forceHttps: true
enablePkce: true
```
### GitLab Setup
1. Go to GitLab Settings > Applications
2. Create new application
3. Add scopes: `openid`, `profile`, `email`
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
5. Save and note Application ID and Secret
---
## GitHub
### Provider URL
```yaml
providerURL: "https://github.com"
```
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oauth-github
spec:
plugin:
traefikoidc:
providerURL: "https://github.com/login/oauth"
clientID: "your-github-client-id"
clientSecret: "your-github-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- user:email
- read:user
allowedUsers:
- "github-username"
forceHttps: true
```
### Limitations
- **OAuth 2.0 only** - Not OpenID Connect
- **No ID tokens** - Only access tokens for API calls
- **No refresh tokens** - Users must re-authenticate on expiry
- **No standard claims** - User info requires API calls
Use GitHub only for API access, not for user authentication with claims.
### GitHub Setup
1. Go to GitHub Settings > Developer settings > OAuth Apps
2. Create new OAuth App
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
4. Note Client ID and generate Client Secret
---
## Generic OIDC
For any OIDC-compliant provider not listed above.
### Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-generic
spec:
plugin:
traefikoidc:
providerURL: "https://oidc.your-provider.com"
clientID: "your-client-id"
clientSecret: "your-client-secret"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-32-char-encryption-key-here"
scopes:
- openid
- profile
- email
forceHttps: true
enablePkce: true
```
### Requirements
- Provider must expose `.well-known/openid-configuration` endpoint
- Must support authorization code flow
- ID tokens must contain required claims (email, sub, etc.)
---
## Automatic Scope Filtering
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
### How It Works
1. Fetches provider's `.well-known/openid-configuration`
2. Extracts `scopes_supported` field
3. Filters requested scopes to only include supported ones
4. Falls back to all requested scopes if provider doesn't declare supported scopes
### Example: Self-Hosted GitLab
Self-hosted GitLab may reject `offline_access` scope:
```yaml
scopes:
- openid
- profile
- email
- offline_access # Will be automatically filtered out if unsupported
```
The middleware will:
1. Read GitLab's discovery document
2. Detect `offline_access` is NOT in `scopes_supported`
3. Filter it out automatically
4. Authentication succeeds
### Logging
```
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
```
### Troubleshooting
If a provider rejects scopes even after filtering:
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
2. Use `overrideScopes: true` with only supported scopes
3. Review middleware debug logs for filtering decisions
+554
View File
@@ -0,0 +1,554 @@
# Redis Cache for Distributed Deployments
Redis cache support for multi-replica Traefik deployments with shared state.
## Table of Contents
- [Overview](#overview)
- [Why Use Redis Cache?](#why-use-redis-cache)
- [Configuration](#configuration)
- [Cache Modes](#cache-modes)
- [Deployment Examples](#deployment-examples)
- [Performance Tuning](#performance-tuning)
- [Monitoring](#monitoring)
- [Troubleshooting](#troubleshooting)
- [Migration Guide](#migration-guide)
---
## Overview
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
### Key Features
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
- **Shared Session Management**: Consistent user sessions across replicas
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
- **Health Checking**: Continuous monitoring of Redis connectivity
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
### Architecture
```
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
│ │ │
└────────────────────┼────────────────────┘
┌──────▼──────┐
│ Redis │
│ (Shared │
│ Cache) │
└─────────────┘
```
---
## Why Use Redis Cache?
### The Problem
When running multiple Traefik instances without shared cache:
1. **False Positive Replay Detection**
- User authenticates → Token stored in Instance A's JTI cache
- Next request → Load balancer routes to Instance B
- Instance B doesn't have the JTI → Falsely detects replay attack
2. **Session Inconsistency**
- User session created on Instance A
- Subsequent request routed to Instance B
- Instance B has no knowledge of the session
3. **Token Metadata Fragmentation**
- Token refresh happens on Instance A
- Other instances continue using old tokens
### The Solution
Redis provides centralized cache that all instances share, ensuring:
- **Consistent Authentication**: All instances share authentication state
- **True Replay Detection**: JTI cache shared across all instances
- **Seamless Scaling**: Add/remove instances without affecting sessions
- **High Availability**: Circuit breaker with automatic fallback
---
## Configuration
### Basic Configuration
```yaml
redis:
enabled: true
address: "redis:6379"
password: "your-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
```
### All Configuration Options
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enabled` | bool | `false` | Enable Redis caching |
| `address` | string | - | Redis server address (`host:port`) |
| `password` | string | - | Redis password (optional) |
| `db` | int | `0` | Redis database number (0-15) |
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
| `poolSize` | int | `10` | Connection pool size |
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
| `readTimeout` | int | `3` | Read timeout (seconds) |
| `writeTimeout` | int | `3` | Write timeout (seconds) |
| `enableTLS` | bool | `false` | Enable TLS for connections |
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
| `enableCircuitBreaker` | bool | `false` | Wrap the Redis backend with a circuit breaker. **Recommended `true` in production.** |
| `circuitBreakerThreshold` | int | `5` | Consecutive failures before the circuit opens (only when `enableCircuitBreaker: true`). |
| `circuitBreakerTimeout` | int | `60` | Seconds the circuit stays open before allowing a probe (only when `enableCircuitBreaker: true`). |
| `enableHealthCheck` | bool | `false` | Wrap the Redis backend with periodic health checks. **Recommended `true` in production.** |
| `healthCheckInterval` | int | `30` | Health check interval in seconds (only when `enableHealthCheck: true`). |
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
### Environment Variables (Fallback)
If not configured through Traefik, these environment variables are used:
```bash
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=your-password
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=10
REDIS_CONNECT_TIMEOUT=5
REDIS_READ_TIMEOUT=3
REDIS_WRITE_TIMEOUT=3
REDIS_ENABLE_TLS=false
REDIS_TLS_SKIP_VERIFY=false
REDIS_HYBRID_L1_SIZE=500
REDIS_HYBRID_L1_MEMORY_MB=10
```
> Resilience fields (`enableCircuitBreaker`, `enableHealthCheck`,
> `circuitBreakerThreshold`, `circuitBreakerTimeout`, `healthCheckInterval`)
> have no environment variable fallback — set them in plugin configuration.
Invalid `cacheMode` values are rejected at plugin startup.
---
## Cache Modes
### Memory Mode (used when Redis is disabled)
```yaml
redis:
cacheMode: "memory"
```
- Uses only in-memory cache
- Suitable for single-instance deployments
- No Redis dependency
- Fastest performance
### Redis Mode
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
- All operations go directly to Redis
- Ensures consistency across replicas
- Slightly higher latency
### Hybrid Mode (Recommended)
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
Two-tier caching strategy:
```
┌─────────────────────────────────────────┐
│ Client Request │
└────────────────┬────────────────────────┘
┌────────────────┐
│ Local Cache │ ← L1 Cache (Fast)
│ (Memory) │
└────────┬───────┘
│ Miss
┌────────────────┐
│ Remote Cache │ ← L2 Cache (Shared)
│ (Redis) │
└────────────────┘
```
**Read Path:**
1. Check local memory cache (L1)
2. On miss, check Redis (L2)
3. On hit in Redis, populate L1
4. Return value
**Write Path:**
1. Write to Redis (L2) for durability
2. Write to local cache (L1) for speed
### Performance Comparison
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|-----------|------------|------------|-------------|
| Read (p50) | 0.1ms | 2ms | 0.2ms |
| Read (p99) | 0.5ms | 10ms | 5ms |
| Write (p50) | 0.2ms | 3ms | 3ms |
| Throughput | 100k/s | 20k/s | 80k/s |
---
## Deployment Examples
### Docker Compose
```yaml
version: '3.8'
services:
redis:
image: redis:7-alpine
command: redis-server --requirepass ${REDIS_PASSWORD}
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 30s
timeout: 3s
retries: 3
traefik:
image: traefik:v3.2
deploy:
replicas: 3
labels:
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
depends_on:
redis:
condition: service_healthy
volumes:
redis-data:
```
### Kubernetes
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-redis
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-encryption-key
callbackURL: /oauth2/callback
redis:
enabled: true
address: "redis-service.redis-namespace:6379"
password: "urn:k8s:secret:redis-secret:password"
db: 0
keyPrefix: "traefikoidc:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
circuitBreakerThreshold: 5
```
### AWS ElastiCache
```yaml
redis:
enabled: true
address: "your-cache.abc123.cache.amazonaws.com:6379"
cacheMode: "hybrid"
enableTLS: true
password: "your-elasticache-auth-token"
```
---
## Performance Tuning
### Connection Pool Sizing
```yaml
redis:
poolSize: 20 # Formula: 2 * CPU cores * replicas
# For 4 cores, 3 replicas: poolSize = 24
```
### TTL Strategy
The plugin automatically sets TTLs based on token lifetimes:
- **JTI Cache**: Matches token lifetime (typically 1 hour)
- **Session**: Matches `sessionMaxAge` configuration
- **Token Metadata**: 5 minutes (short-lived)
### Redis Server Configuration
```bash
# Recommended Redis settings for cache
maxmemory 512mb
maxmemory-policy allkeys-lru # Evict least recently used
# For cache data, disable persistence for better performance
save ""
appendonly no
```
### Hybrid Mode Tuning
```yaml
redis:
cacheMode: "hybrid"
hybridL1Size: 500 # Max items in local cache
hybridL1MemoryMB: 10 # Max memory for local cache
```
---
## Monitoring
### Key Metrics
- **Cache hit rate** (target: >90% for hybrid mode)
- **Redis latency** (target: <10ms p99)
- **Circuit breaker state**
- **Connection pool utilization
### Redis Commands for Monitoring
```bash
# Monitor commands in real-time
redis-cli MONITOR
# Check slow queries
redis-cli SLOWLOG GET 10
# Memory usage
redis-cli INFO memory
# Key statistics
redis-cli DBSIZE
# List keys with prefix
redis-cli --scan --pattern "traefikoidc:*"
# Check key TTL
redis-cli TTL "traefikoidc:session:abc123"
```
### Health Check Endpoint
The plugin provides health information including:
```json
{
"status": "healthy",
"cache": {
"mode": "hybrid",
"redis": {
"connected": true,
"latency": "2ms"
},
"circuit_breaker": {
"state": "closed",
"failures": 0
}
}
}
```
---
## Troubleshooting
### Connection Refused
**Symptoms:** `dial tcp: connection refused`
**Solutions:**
1. Verify Redis is running: `redis-cli ping`
2. Check network connectivity: `telnet redis-host 6379`
3. Verify address configuration
### Authentication Failure
**Symptoms:** `NOAUTH Authentication required`
**Solutions:**
1. Set Redis password in configuration
2. Verify password is correct
### Circuit Breaker Open
**Symptoms:** `Circuit breaker is open`, falling back to memory
**Solutions:**
1. Check Redis health: `redis-cli INFO server`
2. Review network latency: `redis-cli --latency`
3. Adjust circuit breaker thresholds if needed
### High Memory Usage
**Symptoms:** Redis memory constantly growing, OOM errors
**Solutions:**
1. Configure eviction policy:
```bash
CONFIG SET maxmemory 512mb
CONFIG SET maxmemory-policy allkeys-lru
```
2. Review key count: `redis-cli DBSIZE`
3. Check for large keys: `redis-cli --bigkeys`
### Inconsistent Cache State
**Symptoms:** Different responses from different replicas
**Solutions:**
1. Verify all instances use the same Redis address
2. Check cache mode consistency across instances
3. Verify time synchronization on all hosts
---
## Migration Guide
### From Memory-Only to Redis
#### Phase 1: Preparation
1. Deploy Redis infrastructure
2. Test Redis connectivity
3. Configure monitoring
#### Phase 2: Gradual Rollout
1. Enable Redis on one instance:
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
```
2. Monitor for errors
3. Gradually enable on more instances
#### Phase 3: Full Migration
1. Enable Redis on all instances
2. Remove `disableReplayDetection: true` if set
3. Monitor for issues
### Rollback Plan
If issues occur:
1. Set `redis.enabled: false`
2. Plugin falls back to memory cache automatically
3. Investigate and resolve issues
### Migration Checklist
- [ ] Redis deployed and accessible
- [ ] Redis password configured
- [ ] Network connectivity verified
- [ ] Monitoring configured
- [ ] Backup plan prepared
- [ ] Test environment validated
- [ ] Gradual rollout planned
---
## Best Practices
### Security
- Always use Redis password authentication
- Enable TLS for production deployments
- Use network segmentation (private subnets)
- Rotate Redis passwords regularly
### High Availability
- Use Redis Sentinel or Cluster for HA
- Configure appropriate circuit breaker thresholds
- Implement proper health checks
- Use connection pooling
### Performance
- Use hybrid cache mode for best performance
- Monitor cache hit rates
- Size Redis memory appropriately
- Disable persistence for cache-only usage
### Operations
- Implement comprehensive monitoring
- Set up alerting for circuit breaker state
- Document Redis configuration
- Test failover scenarios
---
## FAQ
### Is Redis required?
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
### What happens if Redis goes down?
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
### Which cache mode should I use?
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
### How much memory does Redis need?
Depends on active sessions and token sizes:
- Small (1-1000 users): 128MB
- Medium (1000-10000 users): 256-512MB
- Large (10000+ users): 1GB+
### Can I use managed Redis services?
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
### Is data encrypted in Redis?
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
+390
View File
@@ -0,0 +1,390 @@
# Testing Guide
Comprehensive testing infrastructure for traefikoidc.
## Overview
| Metric | Value |
|--------|-------|
| Test files | 110 |
| Lines of test code | ~72,000 |
| Code coverage | 71.0% |
| Race conditions | None (all pass with `-race`) |
## Running Tests
```bash
# Run all tests
go test ./...
# Run with race detection
go test -race ./...
# Run with coverage
go test -cover ./...
# Run specific test suite
go test -v -run "TokenValidationSuite" .
# Run edge case tests
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
```
## Test Infrastructure
### Directory Structure
```
internal/testutil/
├── compat.go # Re-exports for main package access
├── mocks/
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
│ ├── session.go # SessionManager, SessionData
│ ├── cache.go # Cache, TokenCache, Blacklist
│ └── interfaces_test.go # Mock verification tests
├── fixtures/
│ └── tokens.go # JWT token generation fixtures
└── servers/
├── oidc.go # Mock OIDC server factory
└── oidc_test.go # Server tests
```
### Test Suites
| Suite | File | Description |
|-------|------|-------------|
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
## Mock Types
The project provides two mocking patterns:
### State-Based Mocks (Basic)
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
| Mock | Interface | Description |
|------|-----------|-------------|
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
**Usage:**
```go
mock := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tOidc := &TraefikOidc{
jwkCache: mock,
// ...
}
```
### Enhanced State-Based Mocks (with Call Tracking)
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
| Mock | Interface | Description |
|------|-----------|-------------|
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
**Usage:**
```go
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
}
// Make calls
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
// Verify calls were made
mock.AssertGetJWKSCalled(t)
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
mock.AssertGetJWKSCallCount(t, 1)
// Access call details
s.Equal(1, mock.GetJWKSCallCount())
```
**Features:**
- Track all calls with parameters and timestamps
- Built-in assertion helpers using testify
- Thread-safe for concurrent tests
- `Reset()` method to clear state between tests
- `LastCall()` to inspect most recent call
### Testify-Based Mocks
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
| Mock | Interface | Description |
|------|-----------|-------------|
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
**Usage:**
```go
mock := &TestifyJWKCache{}
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
// After test
mock.AssertExpectations(t)
```
### Testutil Package Mocks
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
```go
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
mock := testutil.NewJWKCacheMock()
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
```
### Choosing the Right Mock
| Use Case | Recommended Mock |
|----------|-----------------|
| Simple return values only | Basic state-based (`MockJWKCache`) |
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
| Verify call order/sequence | Testify-based |
| HTTP endpoint simulation | `MockOAuthProvider` |
| New testify suite tests | Enhanced or Testify-based |
**Decision Guide:**
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
## Token Fixtures
The `testutil.TokenFixture` generates JWT tokens for testing:
```go
fixture, err := testutil.NewTokenFixture()
// Valid token with default claims
token, _ := fixture.ValidToken(nil)
// Token with custom claims
token, _ := fixture.ValidToken(map[string]interface{}{
"email": "test@example.com",
"roles": []string{"admin"},
})
// Expired token
token, _ := fixture.ExpiredToken()
// Token with specific roles/groups
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
token, _ := fixture.TokenWithGroups([]string{"developers"})
// Token with clock skew
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
// Token missing specific claims
token, _ := fixture.TokenMissingClaim("email", "sub")
// Malformed token
token := fixture.MalformedToken() // "not.a.valid.jwt"
// Get JWKS for verification
jwks := fixture.GetJWKS()
```
## Mock OIDC Server
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
```go
// Default configuration
server := testutil.NewOIDCServer(nil)
defer server.Close()
// Custom configuration
config := testutil.DefaultServerConfig()
config.Issuer = "https://custom-issuer.com"
config.TokenError = &testutil.OIDCError{
Error: "invalid_grant",
Description: "Authorization code expired",
}
server := testutil.NewOIDCServer(config)
// Provider-specific configurations
googleConfig := testutil.GoogleServerConfig()
azureConfig := testutil.AzureServerConfig()
auth0Config := testutil.Auth0ServerConfig()
keycloakConfig := testutil.KeycloakServerConfig()
// Behavior configurations
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
```
### Server Endpoints
| Endpoint | Description |
|----------|-------------|
| `/.well-known/openid-configuration` | OIDC discovery document |
| `/authorize` | Authorization endpoint |
| `/token` | Token exchange endpoint |
| `/jwks` | JSON Web Key Set |
| `/userinfo` | User information endpoint |
| `/introspect` | Token introspection |
| `/revoke` | Token revocation |
| `/logout` | End session endpoint |
### Request Tracking
```go
server := testutil.NewOIDCServer(nil)
// Make requests...
count := server.GetRequestCount()
requests := server.GetRequests()
server.Reset() // Clear tracking
```
## Writing Test Suites
### Basic Suite Structure
```go
type MyTestSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *MyTestSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *MyTestSuite) SetupTest() {
// Per-test setup
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
// ...
}
}
func (s *MyTestSuite) TearDownTest() {
// Per-test cleanup
}
func (s *MyTestSuite) TestSomething() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err)
}
func TestMyTestSuite(t *testing.T) {
suite.Run(t, new(MyTestSuite))
}
```
### Table-Driven Tests
```go
func (s *MyTestSuite) TestClockSkewEdgeCases() {
testCases := []struct {
name string
skew time.Duration
shouldPass bool
}{
{"valid_token", 5 * time.Minute, true},
{"expired_within_tolerance", -1 * time.Minute, true},
{"expired_beyond_tolerance", -10 * time.Minute, false},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
token, err := s.fixture.TokenWithSkew(tc.skew)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
if tc.shouldPass {
s.NoError(err)
} else {
s.Error(err)
}
})
}
}
```
## Test Categories
### Happy Path Tests
Test the expected successful scenarios:
- Valid token verification
- Successful token exchange
- Session creation and retrieval
- Cache operations
### Error Case Tests
Test failure scenarios:
- Expired tokens
- Invalid signatures
- Wrong issuer/audience
- Network failures
- Rate limiting
### Edge Case Tests
Test boundary conditions:
- Clock skew tolerance boundaries
- Unicode/emoji in claims
- Very large claim values
- Concurrent access
- Special characters in URLs
## Best Practices
1. **Use fixtures for token generation** - Don't manually construct JWTs
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
3. **Always run with `-race`** - Catch concurrency issues early
4. **Use testify assertions** - Better error messages and cleaner code
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
6. **Test edge cases systematically** - Use table-driven tests
-163
View File
@@ -1,163 +0,0 @@
# Google OAuth Integration Fix
## Problem Overview
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
```
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
```
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
## Technical Details of the Issue
### Standard OIDC Provider Behavior
Most OpenID Connect (OIDC) providers follow the standard specification, where:
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
- This allows authenticated sessions to persist beyond the initial access token expiration
### Google's Non-Standard Approach
Google's OAuth implementation deviates from the standard by:
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
## Solution Implementation
The fix involved modifying the authentication flow to specifically handle Google providers:
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
```go
// Check if we're dealing with a Google OIDC provider
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
strings.Contains(t.issuerURL, "accounts.google.com")
```
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
```go
// Handle offline access differently for Google vs other providers
if isGoogleProvider {
// For Google, use access_type=offline parameter instead of offline_access scope
params.Set("access_type", "offline")
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
// Add prompt=consent for Google to ensure refresh token is issued
params.Set("prompt", "consent")
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else {
// For non-Google providers, use the offline_access scope
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
}
}
```
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
## Why This Approach Works
This solution aligns with Google's OAuth 2.0 documentation which specifies:
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
## Testing and Verification
Comprehensive tests were implemented to verify the solution:
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
2. **Auth URL Parameter Tests**: Verifies that:
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
Sample test case (simplified):
```go
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that access_type=offline was added (not offline_access scope for Google)
if !strings.Contains(authURL, "access_type=offline") {
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
}
// Verify offline_access scope is NOT included for Google providers
if strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
}
// Check that prompt=consent was added
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
}
})
```
## Usage Guidance for Developers
When configuring the Traefik OIDC middleware for Google:
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
- Configure the authorized redirect URI to match your `callbackURL` setting
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
3. **Configuration Example**:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-google-client-id.apps.googleusercontent.com
clientSecret: your-google-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
scopes:
- openid
- email
- profile
# Note: DO NOT manually add offline_access scope for Google
# The middleware handles this automatically and correctly
```
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
- Review your application logs with `logLevel: debug` to check for refresh token errors
- Verify you're using a version of the middleware that includes this fix
## Conclusion
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
+1525
View File
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,459 @@
# Bearer Token Authentication — Design Spec
- **Date**: 2026-05-18
- **Status**: Design — pending implementation plan
- **Supersedes**: PR #93 (broken implementation; recommended to close in favour of this design)
## 1. Summary
Add an opt-in path that lets API clients (machine-to-machine) authenticate by presenting a signed access token in the `Authorization: Bearer <token>` header, bypassing the cookie-based OIDC redirect flow. Identity, roles, and authorization checks remain consistent with the existing cookie path; the only thing that changes is how the principal is established for that single request.
The feature is implemented by extracting a shared `forwardAuthorized` pipeline from the existing `processAuthorizedRequest`, introducing a `principal` value type, and adding a small bearer-specific entrypoint that builds a principal directly from a verified JWT — without synthesising a fake `SessionData`.
## 2. Motivation
PR #93 attempted this feature by building an in-memory `SessionData` from JWT claims and reusing `processAuthorizedRequest`. The approach has three latent defects:
1. The synthetic session omits `mainSession.Values["user_identifier"]`. `processAuthorizedRequest` reads it via `GetUserIdentifier()`; when empty it bails to `defaultInitiateAuthentication` and issues an OIDC redirect. The feature is non-functional in practice despite the unit test passing.
2. `verifyToken` accepts both ID tokens (audience match against `clientID`) and access tokens. ID tokens are not API credentials; treating them as such is a classic token-confusion vector.
3. `verifyToken` adds JTI to the replay blacklist on first verify. Once the verified-token cache evicts, subsequent reuse of the same bearer token triggers a false-positive replay rejection.
Rather than patch a synthetic-session approach that will keep generating bugs as `SessionData` evolves, this spec replaces it with a cleaner abstraction where session lifecycle and post-auth header injection live in separate units.
## 3. Goals
- Accept `Authorization: Bearer <jwt>` from M2M clients, validate the token, and forward the request downstream with identity headers populated.
- Enforce the same `allowedRolesAndGroups` policy as the cookie path.
- Default-off; safe defaults when enabled (audience required, ID tokens rejected, identifier sanitised).
- No behavioural change to the cookie path. Existing tests must continue to pass without modification.
## 4. Non-Goals
- Human-user / browser flows. Bearer is M2M-only in this iteration.
- Pure opaque access tokens on the bearer path. Tokens must be JWTs; introspection (RFC 7662) is supported *on top of* JWT verification for revocation state, not as a substitute for it.
- mTLS, API keys, or any other auth method. The `principal` abstraction enables them later, but they are not delivered here.
- Per-route bearer configuration. Single middleware-wide setting.
## 5. Decided Requirements
| Topic | Decision |
|---|---|
| Consumer type | Machine-to-machine (M2M) only |
| Token format | JWT only (signature, issuer, audience, exp) |
| Audience | Mandatory when feature enabled; startup fails if `Audience == ""` |
| Token type | Access tokens only; ID tokens explicitly rejected |
| Revocation | JWT-only verification by default; introspection (RFC 7662) opt-in via existing `RequireTokenIntrospection` |
| Identity claim | New `BearerIdentifierClaim` config (string, default `"sub"`). Bearer path reads this claim exclusively; does NOT use `UserIdentifierClaim` (which defaults to `"email"` and drives the cookie path). Resolved value must be a non-empty string. `sub` is mandatory per `jwt.go:416` regardless, so even with a different `BearerIdentifierClaim` the token must still carry a valid `sub`. Decoupling avoids the M2M-vs-human-user identity-claim conflict and the email-spoofing footgun. |
| Identifier sanitisation | Reject value containing any `unicode.IsControl` char, any Unicode bidi-override (U+202AU+202E, U+2066U+2069), leading/trailing whitespace, commas, semicolons, equals signs. Max length 256 bytes. |
| Token classifier | **Reuse existing `detectTokenType(jwt, token)` at `token_manager.go:187-303`** which already handles `nonce`, `typ: at+jwt`, `token_use`, `scope`, and aud-vs-clientID priority. Bearer path rejects any token where `detectTokenType == true` (ID token). Do not invent a parallel classifier. |
| Algorithm pinning | Hard-pin `alg ∈ {RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512}`, enforced **before** JWKS lookup on the bearer path. Prevents wasted JWKS fetches for `alg=none`/HS attacker probes. |
| `kid` hardening | `kid` ≤ 256 bytes, charset `[A-Za-z0-9._\-=]`. Reject before JWKS lookup. |
| Token age | Bearer path enforces `now - iat <= MaxTokenAgeSeconds` (default 86400 / 24h, configurable). Cookie path unchanged. |
| Multi-audience policy | If `aud` is an array (length > 1), require `azp` claim to be present and equal to `clientID`. Single-string `aud` unaffected. |
| Mixed bearer + cookie precedence | **Cookie wins by default** when both are presented (safer for browser scenarios). Operator opt-in: `BearerOverridesCookie=true` to flip. Either way, a warning is logged on the request. |
| Bearer + excluded URL | `Authorization` header is **stripped** before forwarding when the request hits an excluded URL. Prevents bearer leaking into public endpoints' downstream logs and prevents recon via excluded paths. |
| Per-source bearer 401 throttle | New sharded cache `failedBearerAttempts` keyed by client IP. After N (default 20) consecutive 401s from one IP within 1 minute, reject further bearer requests from that IP with 429 for 60s. Applied BEFORE `verifyToken` to deny JWKS amplification. |
| `Authorization` header passthrough | New `StripAuthorizationHeader` config, default `true` |
| Roles/groups gating | Same `allowedRolesAndGroups` rules as cookie path |
| Default state | `EnableBearerAuth` = `false` |
| JTI replay marking | Suppressed on bearer path; cookie path unchanged |
| Failure response shape | 401 with generic body; `WWW-Authenticate: Bearer error="invalid_token"` per RFC 6750 |
| Introspection endpoint outage | 503 (distinguishes infra outage from token rejection) |
| Mixed bearer + cookie | Bearer wins; cookie ignored on that request |
| SSE/WS bypass + bearer | Bypass paths keep cookie-only check; bearer header ignored on SSE/WS |
## 6. Architecture
```
┌──────────────────┐
HTTP req ──► │ ServeHTTP │ (existing entry; adds bearer detection)
└─────────┬────────┘
┌───────────┴────────────┐
▼ ▼
cookie / session bearer (Authorization: Bearer …)
│ │
▼ ▼
┌────────────────┐ ┌────────────────────┐
│ buildPrincipal │ │ buildPrincipal │
│ FromSession() │ │ FromBearerToken() │
└────────┬───────┘ └─────────┬──────────┘
│ produces *principal │
└──────────────┬───────────┘
┌────────────────────────────┐
│ forwardAuthorized(rw,req,p)│ (shared pipeline)
│ • roles/groups gate │
│ • header injection │
│ • header templates │
│ • security headers │
│ • cookie stripping │
│ • next.ServeHTTP │
└────────────────────────────┘
```
**Invariant**: `forwardAuthorized` never touches session storage. Session-specific concerns (Save, IsDirty, backchannel-logout invalidation) stay inside `processAuthorizedRequest` around the call to `forwardAuthorized`.
**Feature gate**: when `EnableBearerAuth == false`, the bearer-detection check in `ServeHTTP` is a no-op. Existing deployments observe byte-identical behaviour.
## 7. Components
### 7.1 `principal` type (new file `principal.go`)
```go
type principalSource int
const (
sourceSession principalSource = iota
sourceBearer
)
type principal struct {
Identifier string // drives X-Forwarded-User
Email string // optional, "" for M2M
Subject string // sub claim
ClientID string // azp / client_id, M2M caller
Claims map[string]interface{} // raw claims for templates / groups
AccessToken string // for X-Auth-Request-Token (gated by minimalHeaders)
IDToken string // "" on bearer path
RefreshToken string // "" on bearer path
Source principalSource
}
```
Pure data. No methods that mutate it. No I/O. No manager pointer.
### 7.2 `buildPrincipalFromSession(*SessionData) *principal` (new in `principal.go`)
Read-only adapter over existing `SessionData` getters: `GetUserIdentifier`, `GetEmail`, `GetAccessToken`, `GetIDToken`, `GetRefreshToken`, cached claims via `GetIDTokenClaims`. Does not write back to the session. This is the only function that still knows about `SessionData`.
### 7.3 `buildPrincipalFromBearerToken(token string) (*principal, error)` (new in `bearer_auth.go`)
1. **Length / format guards**: `len(token) <= AccessTokenConfig.MaxLength`, exactly two dots, non-empty after trim.
2. **Parse header for early alg/kid pinning** (without trusting payload): decode JOSE header; reject if `alg` ∉ asymmetric allowlist; reject if `kid` missing, > 256 bytes, or contains chars outside `[A-Za-z0-9._\-=]`. This happens **before** JWKS lookup so attacker noise doesn't amplify into JWKS fetches.
3. **Per-IP 401 throttle check**: if this IP is in the `failedBearerAttempts` penalty box, return 429 immediately.
4. `t.verifyToken(token, verifyOpts{skipReplayMarking: true})` — reuses signature, issuer, audience, expiration, JTI Get (replay detection). The `skipReplayMarking` flag gates ONLY the JTI Set at `token_manager.go:108-143`; the JTI Get at `token_manager.go:44-47, 80-89` remains active so revoked tokens (via `RevokeToken` adding to blacklist) are still rejected.
5. **Re-parse claims** (`parseJWT(token)` is cheap and already done internally; reuse via a single decode if practical).
6. **Token-type guard**: call existing `detectTokenType(jwt, token)` (`token_manager.go:187-303`). Reject when it returns `true` (ID token). Belt-and-braces: also reject if `claims["nonce"]` is a non-empty string or `claims["token_use"] == "id"`.
7. **Multi-audience hardening**: if `claims["aud"]` is a `[]interface{}` with length > 1, require `claims["azp"]` to be a non-empty string equal to `t.clientID`; reject otherwise.
8. **`iat` upper-age bound**: reject when `time.Now().Unix() - int64(claims["iat"].(float64)) > MaxTokenAgeSeconds` (default 86400).
9. **Optional introspection**: if `requireTokenIntrospection` is set, call `introspectToken`; reject if `active == false` (401); surface 503 on transport failure. Bearer-path introspection cache TTL is capped at 60s (not 5min) to keep the "real-time revocation" promise close to true.
10. **Identifier resolution**: read `t.bearerIdentifierClaim` (defaults to `"sub"`); do NOT use `t.userIdentifierClaim` (cookie path's setting, default `email`). The bearer path does NOT fall back to other claims because `jwt.Verify` already enforces non-empty `sub` (`jwt.go:416-419`). Empty/missing identifier → 401.
11. **Identifier sanitisation**: trim, then reject if length > 256 OR contains any of: `unicode.IsControl`, bidi-override (U+202AU+202E, U+2066U+2069), `,`, `;`, `=`.
12. Return `&principal{ Source: sourceBearer, … }`.
On any failure path: increment the per-IP `failedBearerAttempts` counter; return the appropriate HTTP status (401 / 403 / 429 / 503) without revealing the failure reason in the response body. Reason is logged at debug only, with the identifier (if resolved) hashed via SHA-256 truncated to 8 hex chars.
### 7.4 `forwardAuthorized(rw, req, *principal)` (new in `middleware.go`, extracted)
The shared post-auth pipeline. Lifted verbatim from the existing `processAuthorizedRequest`:
1. Roles/groups extraction via existing `extractGroupsAndRolesFromClaims`.
2. `allowedRolesAndGroups` gate (existing logic).
3. Inject `X-Forwarded-User`, `X-User-Groups`, `X-User-Roles`.
4. Inject `X-Auth-Request-*` (gated by `minimalHeaders`).
5. Header templates.
6. Security headers.
7. Cookie strip when `stripAuthCookies`.
8. **New**: `Authorization` header strip when `stripAuthorizationHeader` AND `principal.Source == sourceBearer`.
9. `t.next.ServeHTTP(rw, req)`.
Does not call `Save`, does not check `IsDirty`. Session persistence stays with the cookie-path caller.
### 7.5 `handleBearerRequest(rw, req)` (new in `bearer_auth.go`)
```
1. Detect "Authorization: Bearer <token>" (case-insensitive prefix).
2. token = TrimSpace(authHeader[7:]); reject empty.
3. p, err := buildPrincipalFromBearerToken(token).
On err → 401 with WWW-Authenticate, log reason at debug.
4. forwardAuthorized(rw, req, p).
```
Target: ~40 lines.
### 7.6 Refactor of `processAuthorizedRequest` (modify `middleware.go`)
Splits along the principal boundary:
- Session-specific part (backchannel-logout invalidation, `IsDirty` / `Save`) stays in `processAuthorizedRequest`.
- Everything else moves to `forwardAuthorized`.
- `processAuthorizedRequest` ends with `forwardAuthorized(rw, req, buildPrincipalFromSession(session))`.
### 7.7 `verifyOpts` extension to `verifyToken` (modify `token_manager.go`)
Add a parameter struct:
```go
type verifyOpts struct {
skipReplayMarking bool // suppress JTI Set (token_manager.go:108-143); blacklist Get stays active
}
```
Both the type and field are unexported (internal-only knob). Signature change: `verifyToken(token string)` becomes `verifyToken(token string, opts verifyOpts)`. Existing callers pass `verifyOpts{}` (zero value = current behaviour). Bearer path passes `verifyOpts{skipReplayMarking: true}`.
**Critical semantics — must be reflected in implementation and tests:**
- `skipReplayMarking` only gates the **Set** at `token_manager.go:108-143` (the call adding the JTI to the blacklist and replay cache).
- The blacklist **Get** at `token_manager.go:44-47, 80-89` stays unconditionally active on the bearer path. Tokens revoked via `RevokeToken` (which adds the JTI to the blacklist) MUST still be rejected on the bearer path.
- Must NOT be implemented by mutating `t.disableReplayDetection` (struct field) — that would create a cross-request race that disables replay protection globally.
A targeted regression test exercises: bearer token verified once → admin calls `RevokeToken` adding the JTI to the blacklist → same token replayed → 401.
### 7.8 Config additions (modify `settings.go`)
```go
EnableBearerAuth bool `json:"enableBearerAuth,omitempty"`
BearerIdentifierClaim string `json:"bearerIdentifierClaim,omitempty"`
StripAuthorizationHeader bool `json:"stripAuthorizationHeader,omitempty"`
BearerEmitWWWAuthenticate bool `json:"bearerEmitWWWAuthenticate,omitempty"`
BearerOverridesCookie bool `json:"bearerOverridesCookie,omitempty"`
MaxTokenAgeSeconds int64 `json:"maxTokenAgeSeconds,omitempty"`
MaxIdentifierLength int `json:"maxIdentifierLength,omitempty"`
BearerFailureThreshold int `json:"bearerFailureThreshold,omitempty"`
BearerFailureWindowSeconds int `json:"bearerFailureWindowSeconds,omitempty"`
BearerFailurePenaltySeconds int `json:"bearerFailurePenaltySeconds,omitempty"`
```
Defaults (applied in `CreateConfig` for the bearer-related fields; values >0 only honoured when `EnableBearerAuth=true`):
- `EnableBearerAuth`: `false`.
- `BearerIdentifierClaim`: `"sub"`.
- `StripAuthorizationHeader`: `true`.
- `BearerEmitWWWAuthenticate`: `true` (RFC 6750 hint enabled by default; flip to false if recon-exposure is a concern).
- `BearerOverridesCookie`: `false` (cookie wins when both present; flip to `true` for the legacy/industry-default behaviour).
- `MaxTokenAgeSeconds`: `86400` (24h upper bound on `iat`).
- `MaxIdentifierLength`: `256`.
- `BearerFailureThreshold`: `20` (consecutive 401s per IP before throttle).
- `BearerFailureWindowSeconds`: `60`.
- `BearerFailurePenaltySeconds`: `60` (429 reply for this long after threshold tripped).
### 7.9 Startup validation (modify `main.go` `New()`)
- `EnableBearerAuth && Audience == ""` → fatal error.
- `EnableBearerAuth && !StrictAudienceValidation` → warning log (recommended hardening).
- `EnableBearerAuth && BearerIdentifierClaim == "email"` → fatal error (the bearer path is M2M and an `email` identifier without `email_verified` enforcement is a spoofing vector; default `BearerIdentifierClaim=sub` avoids this; explicit override to `email` is rejected).
- `EnableBearerAuth && MaxTokenAgeSeconds <= 0` → reset to default 86400 with info log.
- `EnableBearerAuth && BearerFailureThreshold <= 0` → reset to default 20 with info log.
## 8. Data Flow
### 8.1 Bearer path
```
ServeHTTP entry (pre-init paths unchanged: logout, backchannel, frontchannel, excluded URLs, SSE/WS bypass)
├─ enableBearerAuth == false? → fall through to cookie path
└─ enableBearerAuth == true AND Authorization starts with "Bearer "
handleBearerRequest
├─ format guards (empty, length, segment count)
verifyToken(token, verifyOpts{SkipReplayMarking: true})
│ signature, issuer, audience (strict), exp
classifyToken(claims) → reject ID tokens
if requireTokenIntrospection: introspectToken → active check
resolveIdentifier(claims) → sanitiseIdentifier
principal{Source: sourceBearer, …}
forwardAuthorized(rw, req, principal)
├─ roles/groups gate (403 on deny)
├─ header injection
├─ header templates
├─ security headers
├─ strip OIDC cookies (existing)
├─ strip Authorization header (new, when configured)
└─ next.ServeHTTP(rw, req)
```
### 8.2 Cookie path (refactored, semantically unchanged)
```
processAuthorizedRequest
1. Session validity / backchannel-logout invalidation (unchanged).
2. principal := buildPrincipalFromSession(session).
3. forwardAuthorized(rw, req, principal).
4. if session.IsDirty(): session.Save().
```
## 9. Error Handling
| Trigger | Status | Body | WWW-Authenticate | Debug log reason |
|---|---|---|---|---|
| Empty bearer after prefix | 401 | `Unauthorized` | `Bearer error="invalid_request"` | empty bearer token |
| Token over MaxLength | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token exceeds max length |
| Not a 3-segment JWT | 401 | `Unauthorized` | `Bearer error="invalid_token"` | malformed JWT |
| Disallowed `alg` (e.g. none, HS*) | 401 | `Unauthorized` | `Bearer error="invalid_token"` | unsupported alg |
| Missing/oversized/bad-charset `kid` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | invalid kid |
| Signature / issuer / aud / exp fail | 401 | `Unauthorized` | `Bearer error="invalid_token"` | reason from verifyToken (category only) |
| `iat` older than MaxTokenAgeSeconds | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token too old (iat outside age bound) |
| Multi-aud without matching `azp` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | multi-aud token without azp match |
| Detected as ID token | 401 | `Unauthorized` | `Bearer error="invalid_token"` | ID tokens not accepted on bearer path |
| JTI blacklisted (revoked) | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token JTI in blacklist |
| Introspection `active=false` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token inactive at IdP |
| Introspection endpoint failure | 503 | `Service Unavailable` | (none) | introspection unavailable |
| Identifier claim missing/empty | 401 | `Unauthorized` | `Bearer error="invalid_token"` | no identifier claim |
| Identifier fails sanitisation | 401 | `Unauthorized` | `Bearer error="invalid_token"` | invalid identifier characters |
| Per-IP failure threshold tripped | 429 | `Too Many Requests` | (none); `Retry-After: <BearerFailurePenaltySeconds>` | source IP in penalty box |
| Roles/groups not allowed | 403 | `Access denied` | (none) | user not in allowedRolesAndGroups |
Responses never include token contents, never include the raw failure reason, and never set `Location` headers (API clients cannot follow redirects).
## 10. Edge Cases
1. **Both bearer header and cookie session present.** Cookie wins by default (safer against browser/extension/proxy bearer injection). `BearerOverridesCookie=true` flips to bearer-wins. Either way: WARN log includes both source markers so operators can audit.
2. **`Authorization: Basic …`.** Not bearer; cookie path runs as today.
3. **`Authorization: Bearer ` (trailing space, no value).** Empty after trim → 401.
4. **Mixed-case prefix (`bearer`, `BEARER`, `BeArEr`).** Case-insensitive prefix check; token value preserved verbatim.
5. **Multiple `Authorization` headers.** Use only the first (Go `http.Header.Get` default). Documented.
6. **Bearer during OIDC init wait.** Bearer requests also block on init: we need `issuerURL`, `audience`, JWKs ready. If init fails, bearer requests return 503 just like cookie requests.
7. **SSE / WebSocket bypass with bearer.** Bypass paths keep cookie-only behaviour. Operators who want bearer on streaming endpoints must remove SSE/WS bypass. Documented.
8. **Logout endpoint with bearer.** Logout runs before bearer detection. Treated as cookie-session logout; bearer token revocation requires IdP-side action.
9. **Excluded URLs with bearer.** Bypass excluded URLs as today; bearer not validated on excluded paths. ADDITIONALLY: `Authorization: Bearer` is stripped from the request before forwarding so the token can't leak into the excluded endpoint's downstream logs / metrics scrapers / health checks.
10. **Concurrent identical bearer requests.** Existing `tokenCache` is concurrency-safe; no new locking.
11. **Client rotates token between requests.** Independent verification per token; independent cache entries.
12. **Clock skew.** Use existing `jwt.Verify` leeway. (If absent, add ±30s as a separate change; out of scope here.)
## 11. Testing Strategy
### 11.1 Integration tests (new `bearer_auth_test.go`)
Table-driven test against a real `httptest.Server` and the full `ServeHTTP` flow. Coverage matrix:
- Valid access token + allowed roles → 200, `next` ran, `X-Forwarded-User` set.
- Valid token without configured roles → 200.
- Wrong audience, expired, tampered signature → 401, `next` did not run.
- ID token presented → 401 (`ID tokens not accepted`).
- Malformed JWT (2 segments) → 401.
- Oversized token (> MaxLength) → 401.
- Empty bearer → 401.
- Missing identifier claim → 401.
- Identifier containing `\r\n` → 401.
- `allowedRolesAndGroups` mismatch → 403.
- `allowedRolesAndGroups` match → 200.
- `EnableBearerAuth=false` + bearer header → cookie path runs (302 to `/authorize`).
- Bearer + valid cookie session → bearer wins, 200.
- `StripAuthorizationHeader=true` → downstream sees no `Authorization`.
- `StripAuthorizationHeader=false` → downstream sees `Authorization`.
- Case variants (`bearer`, `BEARER`) → 200.
- SSE bypass + bearer → cookie-only check applies (bearer ignored).
- **Replay regression**: same token 1000 times in a row → all 200.
- **Cache-evict regression**: same token, force-evict `tokenCache` between iterations (call `tokenCache.Delete` directly), replay → still 200 (verifies `skipReplayMarking` doesn't poison the blacklist).
- **Revocation-while-bearer regression**: bearer token verified once → admin calls `RevokeToken` adding JTI to blacklist → same token presented → 401 (verifies blacklist Get stays active on bearer path even with `skipReplayMarking` set).
- **Alg-pin: token signed with `alg=none`** → 401, no JWKS fetch happens (verify with a counting mock).
- **`kid` injection: 50KB random kid** → 401 immediately, no JWKS fetch.
- **Per-IP throttle**: 21 bad bearer requests from same IP within 1 minute → 22nd returns 429 + Retry-After.
- **`iat` upper-age**: token with `iat = now - 25h` → 401 (older than 24h default).
- **Multi-aud without azp**: aud = `["a", "b"]`, no azp → 401.
- **Multi-aud with matching azp**: aud = `["api-aud", "other"]`, azp = clientID → 200.
- **Identifier with bidi-override**: sub contains U+202E → 401.
- **Identifier with comma**: sub = `"alice,bob"` → 401.
- **Identifier over 256 bytes** → 401.
- **`UserIdentifierClaim=email` at startup with EnableBearerAuth=true** → startup fails.
- **Excluded URL + bearer**: bearer header presented on excluded URL → request forwarded, downstream sees no `Authorization` header (stripped).
### 11.2 Unit tests (in `bearer_auth_test.go`)
- `classifyToken`: ID-token detection, access-token detection by `scope`/`scp`/`token_use`, ambiguous → reject.
- `resolveIdentifier`: precedence (`userIdentifierClaim``sub``client_id`/`azp`); missing → error; empty string → error.
- `sanitizeIdentifier`: rejects all `unicode.IsControl`; accepts email/sub-style values.
### 11.3 Introspection tests (`bearer_auth_introspection_test.go`)
- Token valid + introspection `active=true` → 200.
- Token valid + introspection `active=false` → 401.
- Introspection endpoint 500 → 503.
- Second request hits introspection cache (no second HTTP call).
### 11.4 Startup validation tests (extend `settings_test.go` / `main_test.go`)
- `EnableBearerAuth=true, Audience=""``New()` errors.
- `EnableBearerAuth=true, StrictAudienceValidation=false` → succeeds with warning.
- `EnableBearerAuth=false` → no validation; existing tests untouched.
### 11.5 Cookie-path regression suite
- All existing `TestServeHTTP_*` tests in `main_servehttp_test.go` pass unmodified.
- Add: cookie session, `EnableBearerAuth=true`, no bearer header → identical behaviour to baseline.
- Add: dirty session still triggers `Save()` after refactor.
### 11.6 Principal invariants
- `buildPrincipalFromSession`: `Source == sourceSession`; `IDToken` / `RefreshToken` populated when present in session.
- `buildPrincipalFromBearerToken`: `Source == sourceBearer`; `IDToken == ""`, `RefreshToken == ""`.
- `forwardAuthorized` produces identical headers for equivalent principals regardless of source.
### 11.7 Coverage gate
- New code in `bearer_auth.go` and `principal.go`: ≥ 90% line coverage.
- `forwardAuthorized` coverage ≥ existing `processAuthorizedRequest` coverage baseline.
### 11.8 Out of scope (follow-ups)
- Load test of bearer vs cookie hot path.
- Fuzzing the JWT parser.
- Additional auth methods (mTLS, API keys) — design enables them, but they are separate work.
## 12. Migration / Rollout
Default-off. Existing deployments observe no behavioural change. Operators opt in by setting:
```yaml
enableBearerAuth: true
audience: https://api.example.com # required when bearer enabled
# optional:
stripAuthorizationHeader: true # default
requireTokenIntrospection: false # default; set true for real-time revocation
userIdentifierClaim: client_id # optional override; defaults to sub fallback chain
```
Documentation: update `docs/CONFIGURATION.md` with a bearer-auth section, and add a new `docs/BEARER_AUTH.md` covering the security model, threat assumptions (token issuer is trusted; audience must be set; bearer means trust the issuer's revocation policy unless introspection enabled), and recommended configurations for common IdPs.
## 13. Security Considerations
| Concern | Mitigation |
|---|---|
| Token confusion (ID token used as bearer) | Reuse `detectTokenType` (`token_manager.go:187-303`) which checks `nonce`, `typ: at+jwt`, `token_use`, `scope`, aud-vs-clientID. Belt-and-braces: explicit `nonce` + `token_use == "id"` rejection on top. |
| Audience confusion (token for service B accepted by A) | `Audience` mandatory at startup; verified via existing `VerifyJWTSignatureAndClaims`; multi-aud tokens require matching `azp == clientID`. |
| Replay-via-blacklist false positive | `verifyOpts{skipReplayMarking: true}` on bearer path. Gates ONLY the Set; the Get stays so revoked tokens still fail. |
| Revocation lag | Optional RFC 7662 introspection. Bearer-path introspection cache TTL capped at 60s. Set `RequireTokenIntrospection=true` for real-time revocation. |
| `alg`-confusion / `alg=none` attacks | Hard-pin asymmetric allowlist at bearer entry, **before** JWKS fetch. Prevents wasted upstream calls and locks out HS/none probes. |
| `kid` injection / JWKS amplification | `kid` length cap (256 bytes) + charset allowlist enforced at bearer entry. |
| Bearer 401 brute-force / oracle | Per-IP `failedBearerAttempts` cache; configurable threshold + penalty box returning 429 + `Retry-After`. |
| `iat` clock-manipulation / forever-tokens | `MaxTokenAgeSeconds` upper bound (default 24h); cookie path unchanged. |
| Identifier-driven header injection | `sanitizeIdentifier`: length cap, control-char + bidi-override + `,;=` rejection. `net/http` rejects CRLF on the wire too (defence in depth). |
| Token leakage downstream | `StripAuthorizationHeader=true` by default. Also: `Authorization` stripped on excluded-URL requests so bearer can't leak into health/metrics downstream logs. |
| Token-in-logs | All log paths log reason categories, not raw tokens. Identifier hashed via SHA-256 truncated to 8 hex chars before any info/warn-level emission (full identifier only at debug). New `safeLogAuthEvent(category, hashedIdentifier, reasonCode)` helper makes this hard to misuse. |
| `email` claim spoofing | Startup fails if `EnableBearerAuth && UserIdentifierClaim == "email"`. Future human-user bearer iteration must add `email_verified` enforcement. |
| Bypass on SSE / WS endpoints | SSE/WS bypass keeps cookie-only behaviour; bearer ignored. Operators choose to widen if needed. |
| Mixed bearer + cookie precedence | Cookie wins by default (safer for browser scenarios); `BearerOverridesCookie=true` flips. WARN log on both-present requests. |
| Configuration drift (operator forgets audience) | Startup fails when `EnableBearerAuth=true && Audience==""`. |
| Downstream blast radius when `StripAuthorizationHeader=false` | Documented: forwarded bearer extends token's blast radius to all downstream services. Logs at those services become token stores. Operators must treat downstream log policy accordingly. |
| Introspection auth method (pre-existing gap, called out) | `token_introspection.go:80` uses `client_secret_basic` only; does not honour `private_key_jwt`. Out of scope for this PR but documented as a follow-up; operators using `ClientAuthMethod=private_key_jwt` + `RequireTokenIntrospection=true` should be aware introspection will use basic auth. |
## 14. Open Questions
None — all design decisions resolved during brainstorming + security review. Implementation may surface incidental questions (e.g. exact clock-skew leeway in `jwt.Verify`); those are out of scope for this spec and handled in the implementation plan.
## 14a. Security Review Reference
This design was reviewed by the `security-reviewer` subagent on 2026-05-18. Findings incorporated:
- **Critical**: C1 (classifier reuses `detectTokenType`), C2 (sub fallback dropped — unreachable due to `jwt.go:416`), C3 (replay-marking gates only Set, not Get; revocation regression test added).
- **High**: H1 (alg pinned at bearer entry), H2 (kid length + charset), H3 (cookie wins by default, configurable), H4 (per-IP 401 throttle), H5 (multi-aud requires azp).
- **Medium**: M1 (identifier max-length + bidi reject + delimiter chars), M2 (introspection cache TTL capped at 60s on bearer path), M4 (log-hashing via SHA-256[:8]), M5 (StripAuth blast-radius documented), M6 (iat upper-age bound), M7 (Authorization stripped on excluded URLs).
- **Low/Nit**: L2 (renamed to `BearerEmitWWWAuthenticate`), N3 (startup rejects `UserIdentifierClaim=email`).
- **Documented as pre-existing gaps (follow-up PRs)**: M3 (introspection auth method doesn't honour `private_key_jwt`).
## 15. Implementation Plan Reference
To be produced by the `writing-plans` skill in a follow-up document at `docs/superpowers/plans/2026-05-18-bearer-token-auth-plan.md`. The plan decomposes this design into ordered, independently-testable PRs.
@@ -0,0 +1,173 @@
# oidcgate — Standalone OIDC Forward-Auth Daemon (Tier 1)
**Date:** 2026-05-19
**Status:** Design approved, pending implementation plan
**Scope:** Tier 1 only — forward-auth daemon. No reverse-proxy mode, no TLS termination, no metrics, no multi-tenancy.
## Goal
Provide a standalone Go binary that exposes the existing `traefikoidc` OIDC middleware as a forward-auth daemon callable from nginx (`auth_request`), Caddy (`forward_auth`), Traefik (`ForwardAuth`), HAProxy, and Envoy (`ext_authz_http`). The library's public surface is **additive only**: no existing exported function signature changes; one optional `Config` field (`TrustForwardedURI`) and one new read-only accessor on `*TraefikOidc` are added. Existing Traefik plugin users see no behavior change unless they opt in to the new field.
## Non-Goals
- Reverse-proxy mode (Tier 2 — separate effort if requested).
- TLS termination inside the daemon.
- Built-in Prometheus metrics.
- Multi-tenant routing (one binary serves one OIDC client config).
- oauth2-proxy CLI/flag compatibility.
- Docker image or goreleaser publishing as part of Tier 1.
## Architecture
### File layout
```
github.com/lukaszraczylo/traefikoidc (library, unchanged)
├── main.go (existing New/NewWithContext)
├── middleware.go (existing ServeHTTP + ~10 LoC patch)
└── cmd/
└── oidcgate/
├── main.go (entrypoint, flags, signal handling)
├── config.go (YAML loader + env-var override walker)
├── server.go (http.ServeMux wiring, listen loop)
├── endpoints.go (auth/start/callback/logout handlers)
├── success.go (synthetic success handler used as `next`)
├── interceptor.go (302→401 response-writer for /oauth2/auth)
├── health.go (/healthz, /readyz)
├── config_test.go
├── endpoints_test.go
└── interceptor_test.go
```
### Process model
Single binary, single `*TraefikOidc` instance built at boot from YAML config, served on one listener port. Health endpoints on the same listener. Graceful shutdown on `SIGINT`/`SIGTERM` via `srv.Shutdown(ctx)` with a 15s deadline, after which context cancellation propagates into the existing goroutine manager.
## Endpoint Contract
All four endpoints share the listener. Paths are configurable; defaults shown.
| Endpoint | Default path | Method | Contract |
|---|---|---|---|
| **Auth probe** | `/oauth2/auth` | GET | Silent. `200 OK` + injected headers on success. `401` on failure (never `302`). Consumed by nginx `auth_request`, Traefik ForwardAuth, Caddy `forward_auth`, Envoy ext_authz_http. |
| **Sign-in** | `/oauth2/start` | GET | Always `302` to IdP `authorize` URL with `state`+`nonce`+PKCE. Reads target URL from `?rd=` query or `X-Forwarded-Uri` header. Hit by the browser after a `401` from `/oauth2/auth`. |
| **Callback** | `config.callbackURL` | GET | IdP `code`+`state` exchange. Existing `auth_flow.go` logic runs unchanged. On success → `302` to original URL. |
| **Logout** | `config.logoutURL` | GET/POST | Existing `logout.go` handler. Terminates session. Honors `oidcEndSessionURL` if configured. |
### Wiring (how each endpoint delegates)
All four handlers feed `(*TraefikOidc).ServeHTTP` after rewriting `req.URL.Path`:
- `/oauth2/callback` → rewrite to `config.callbackURL`, delegate. Middleware path-match at `middleware.go` triggers callback flow.
- `/oauth2/logout` → rewrite to `config.logoutURL`, delegate. Middleware logout path-match at the top of `ServeHTTP` triggers.
- `/oauth2/start` → rewrite `req.URL.Path` to the sentinel `/__oidcgate_protected__`, delegate. Middleware sees an unauthenticated GET on a protected path, emits the `302` to IdP. The redirect flows through naturally.
- `/oauth2/auth` → rewrite `req.URL.Path` to `/__oidcgate_protected__`, wrap `ResponseWriter` with an interceptor (see below), delegate.
The sentinel path `/__oidcgate_protected__` is chosen because it cannot collide with `callbackURL` / `logoutURL` / `/health*` path matches inside `ServeHTTP` and is not a likely user-configured `excludedURLs` entry. It is internal-only: clients never see it.
### Synthetic `next` handler
The middleware calls `t.next.ServeHTTP(rw, req)` at four sites (`middleware.go:174,185,187,592`) when the request is authenticated and should be forwarded. The daemon supplies a `next` that:
1. Writes `200 OK`.
2. Mirrors any `X-Forwarded-*` and templated headers that the middleware set on `req.Header` (e.g. `X-Forwarded-User` at `middleware.go:101,512`) onto the **response** headers, so proxies can capture them via `auth_request_set` / `authResponseHeaders`.
3. Writes empty body.
### 302 → 401 interceptor (`/oauth2/auth` only)
nginx `auth_request` cannot follow `302`s. For the silent endpoint, the daemon wraps the `ResponseWriter` such that:
- If the middleware writes status `302` (the IdP-redirect branch), the interceptor rewrites it to `401`.
- `Location` header from the swallowed `302` is preserved as `X-Auth-Redirect` on the `401` response (advisory; some proxies may surface it).
- `Set-Cookie` headers (state, PKCE, nonce) are preserved verbatim so the browser carries them into the subsequent `/oauth2/start` request.
- For any non-`302` status, the interceptor is a passthrough.
`/oauth2/start` does **not** wrap; the middleware's natural `302` flows through to the browser.
## Configuration
### Source
- **File:** `--config /etc/oidcgate/config.yaml` (path overridable via flag).
- **Format:** YAML, unmarshalled into the existing `traefikoidc.Config` struct (`settings.go:39`) via `yaml.v3` (already a dependency in `go.mod`).
- **Migration from Traefik:** copy the `plugin.traefikoidc:` subtree out of `.traefik.yml` and add the daemon-specific top-level keys below.
- **Env-var overrides** (secrets in particular): after YAML unmarshal, walk the config struct. Any scalar string/int/bool field with a non-empty `OIDCGATE_<UPPER_SNAKE_CASE_FIELD>` env-var replaces the YAML value. Nested structs (`Redis`, `SecurityHeaders`, `DynamicClientRegistration`) stay YAML-only.
### Top-level oidcgate-specific keys
```yaml
listen: ":8080" # required, listener address
authPath: "/oauth2/auth" # optional, default shown
startPath: "/oauth2/start" # optional, default shown
# all other keys = existing traefikoidc.Config fields
```
### Validation
The existing validation inside `traefikoidc.NewWithContext` (`main.go:97`) runs unchanged. On any returned error, the daemon logs the error and exits non-zero.
## Library-Side Patches
Two additive changes, both default-off / read-only:
1. **`settings.go` + `middleware.go` (~10 LoC):** add `Config.TrustForwardedURI bool` (default `false`). When `true`, the post-login-redirect target captured during the "unauthenticated GET → 302 to IdP" branch is sourced from `req.Header.Get("X-Forwarded-Uri")` if non-empty, instead of from `req.URL`. The daemon sets `TrustForwardedURI = true` at config build time. Default-off preserves current Traefik plugin behavior exactly.
2. **`main.go` (~5 LoC):** add `func (t *TraefikOidc) Ready() bool` returning `true` once at least one successful OIDC metadata discovery fetch has populated the metadata cache. Read-only; no behavior change for existing consumers.
## Lifecycle
```
parse flags
→ load YAML
→ apply env-var overrides
→ build synthetic success handler
→ call traefikoidc.New(ctx, success, cfg, "oidcgate") (validation happens here)
→ build mux (auth, start, callback, logout, healthz, readyz)
→ http.Server.ListenAndServe on cfg.listen
→ wait for SIGINT/SIGTERM
→ srv.Shutdown(15s ctx)
→ ctx cancel propagates to goroutine manager
→ exit 0
```
`/readyz` returns `200` only once `traefikoidc.New` has returned without error **and** the first OIDC metadata discovery fetch has succeeded. Implementation: add a read-only accessor `func (t *TraefikOidc) Ready() bool` that returns `true` once the metadata cache has at least one successful discovery fetch. The daemon's `/readyz` handler calls this and returns `200` / `503` accordingly.
`/healthz` returns `200` as long as the process is alive.
## Testing
- `config_test.go` — YAML round-trip; env-var override precedence; validation pass-through.
- `endpoints_test.go``httptest.NewServer`-based scenarios:
- `/oauth2/auth` with no session → `401`.
- `/oauth2/auth` with valid session → `200` + `X-Forwarded-User` mirrored on response.
- `/oauth2/start``302` with valid IdP `authorize` URL incl. `state` and PKCE.
- `/oauth2/callback` → completes exchange, sets session, redirects to original URL.
- `/oauth2/logout` → clears session cookie.
- `interceptor_test.go` — middleware-emitted `302` becomes `401`; `Location``X-Auth-Redirect`; `Set-Cookie` preserved.
- Reuse existing mock IdP from `enhanced_mocks_test.go`. No new mock infra.
## Risks & Mitigations
| Risk | Mitigation |
|---|---|
| Interceptor swallows a legitimate `302` from the middleware that isn't the IdP redirect. | Inspect: only intercept when `Location` matches `t.providerURL`'s authorize endpoint or when it points off-host. Test coverage in `interceptor_test.go`. |
| Library-side patch breaks current Traefik users. | New `TrustForwardedURI` defaults to `false`; existing path untouched when unset. |
| Env-var walker overreaches into nested structs. | Restrict to top-level scalar fields; document explicitly; nested structs stay YAML-only. |
| Path-rewrite trick hits a middleware path-comparison we didn't anticipate. | All four `t.next.ServeHTTP` sites verified at `middleware.go:174,185,187,592`. Endpoint tests exercise each path. |
## Out of Scope (Tier 2 candidates)
- Reverse-proxy mode (`httputil.ReverseProxy` as the configured `next`).
- TLS termination (`tls.Config`, ACME).
- Prometheus metrics endpoint.
- Multi-tenant routing.
- oauth2-proxy flag/env compatibility.
- Goreleaser binaries, Docker image, systemd unit.
## Acceptance Criteria
1. `go build ./cmd/oidcgate` produces a working binary.
2. Existing test suite (`go test ./...`) still passes — zero regressions in the library.
3. New endpoint tests pass for all four endpoints against the mock IdP.
4. `oidcgate --config example.yaml` boots, serves `/healthz`, performs end-to-end OIDC flow against a real IdP (manual smoke test against e.g. a local Keycloak).
5. README section documents nginx, Caddy, and Traefik wiring examples.
+609
View File
@@ -0,0 +1,609 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
)
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
SubjectType string `json:"subject_type,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// ClientRegistrationError represents an error response from client registration (RFC 7591)
type ClientRegistrationError struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
}
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
type DynamicClientRegistrar struct {
httpClient *http.Client
logger *Logger
config *DynamicClientRegistrationConfig
registrationResponse *ClientRegistrationResponse
store DCRCredentialsStore // Storage backend for credentials
providerURL string
mu sync.RWMutex
}
// NewDynamicClientRegistrar creates a new dynamic client registrar
func NewDynamicClientRegistrar(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
}
}
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
func NewDynamicClientRegistrarWithStore(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
store DCRCredentialsStore,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
store: store,
}
}
// SetStore sets the credentials store for the registrar
// This allows setting the store after creation when the cache manager is available
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
r.mu.Lock()
defer r.mu.Unlock()
r.store = store
}
// RegisterClient performs dynamic client registration with the OIDC provider
// It first attempts to load existing credentials from storage if persistence is enabled,
// then registers a new client if no valid credentials exist.
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
if r.config == nil || !r.config.Enabled {
return nil, fmt.Errorf("dynamic client registration is not enabled")
}
// Try to load existing credentials if persistence is enabled
if r.config.PersistCredentials {
resp, err := r.loadCredentialsFromStore(ctx)
if err != nil {
r.logger.Debugf("Failed to load credentials from store: %v", err)
} else if resp != nil {
// Check if credentials are still valid (not expired)
if r.areCredentialsValid(resp) {
r.logger.Info("Loaded existing client credentials from storage")
r.mu.Lock()
r.registrationResponse = resp
r.mu.Unlock()
return resp, nil
}
r.logger.Info("Existing credentials expired or invalid, registering new client")
}
}
// Determine registration endpoint
endpoint := registrationEndpoint
if r.config.RegistrationEndpoint != "" {
endpoint = r.config.RegistrationEndpoint
}
if endpoint == "" {
return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured")
}
// Validate the endpoint URL
if !strings.HasPrefix(endpoint, "https://") {
// Allow http only for localhost/development
if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") {
return nil, fmt.Errorf("registration endpoint must use HTTPS for security")
}
r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint)
}
// Build registration request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build registration request: %w", err)
}
r.logger.Debugf("Registering client at endpoint: %s", endpoint)
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
// Add Initial Access Token if provided
if r.config.InitialAccessToken != "" {
req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken)
}
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
if err != nil {
return nil, fmt.Errorf("failed to read registration response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse registration response: %w", err)
}
// Validate response
if regResp.ClientID == "" {
return nil, fmt.Errorf("registration response missing client_id")
}
r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID)
// Cache the response
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist client credentials: %v", err)
// Don't fail registration if persistence fails
}
}
return &regResp, nil
}
// buildRegistrationRequest creates the JSON request body for client registration
func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) {
metadata := r.config.ClientMetadata
if metadata == nil {
metadata = &ClientRegistrationMetadata{}
}
// Build request object
reqData := make(map[string]interface{})
// Required: redirect_uris
if len(metadata.RedirectURIs) > 0 {
reqData["redirect_uris"] = metadata.RedirectURIs
} else {
return nil, fmt.Errorf("redirect_uris is required for client registration")
}
// Optional fields - only include if set
if len(metadata.ResponseTypes) > 0 {
reqData["response_types"] = metadata.ResponseTypes
} else {
// Default to authorization code flow
reqData["response_types"] = []string{"code"}
}
if len(metadata.GrantTypes) > 0 {
reqData["grant_types"] = metadata.GrantTypes
} else {
// Default grant types for authorization code flow
reqData["grant_types"] = []string{"authorization_code", "refresh_token"}
}
if metadata.ApplicationType != "" {
reqData["application_type"] = metadata.ApplicationType
}
if len(metadata.Contacts) > 0 {
reqData["contacts"] = metadata.Contacts
}
if metadata.ClientName != "" {
reqData["client_name"] = metadata.ClientName
}
if metadata.LogoURI != "" {
reqData["logo_uri"] = metadata.LogoURI
}
if metadata.ClientURI != "" {
reqData["client_uri"] = metadata.ClientURI
}
if metadata.PolicyURI != "" {
reqData["policy_uri"] = metadata.PolicyURI
}
if metadata.TOSURI != "" {
reqData["tos_uri"] = metadata.TOSURI
}
if metadata.JWKSURI != "" {
reqData["jwks_uri"] = metadata.JWKSURI
}
if metadata.SubjectType != "" {
reqData["subject_type"] = metadata.SubjectType
}
if metadata.TokenEndpointAuthMethod != "" {
reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod
} else {
// Default to client_secret_basic for confidential clients
reqData["token_endpoint_auth_method"] = "client_secret_basic"
}
if metadata.DefaultMaxAge > 0 {
reqData["default_max_age"] = metadata.DefaultMaxAge
}
if metadata.RequireAuthTime {
reqData["require_auth_time"] = metadata.RequireAuthTime
}
if len(metadata.DefaultACRValues) > 0 {
reqData["default_acr_values"] = metadata.DefaultACRValues
}
if metadata.Scope != "" {
reqData["scope"] = metadata.Scope
}
return json.Marshal(reqData)
}
// GetCachedResponse returns the cached registration response
func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse {
r.mu.RLock()
defer r.mu.RUnlock()
return r.registrationResponse
}
// areCredentialsValid checks if the cached credentials are still valid
func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool {
if resp == nil || resp.ClientID == "" {
return false
}
// Check if secret has expired
if resp.ClientSecretExpiresAt > 0 {
expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0)
// Add 5 minute buffer before expiration
if time.Now().Add(5 * time.Minute).After(expiresAt) {
return false
}
}
return true
}
// credentialsFilePath returns the path for storing credentials
func (r *DynamicClientRegistrar) credentialsFilePath() string {
if r.config.CredentialsFile != "" {
return r.config.CredentialsFile
}
return "/tmp/oidc-client-credentials.json"
}
// loadCredentialsFromStore loads client credentials from the configured storage backend
// Falls back to legacy file-based loading if no store is configured
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
// Use store if available
if r.store != nil {
return r.store.Load(ctx, r.providerURL)
}
// Fallback to legacy file-based loading
return r.loadCredentials()
}
// saveCredentialsToStore persists client credentials to the configured storage backend
// Falls back to legacy file-based saving if no store is configured
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
// Use store if available
if r.store != nil {
return r.store.Save(ctx, r.providerURL, resp)
}
// Fallback to legacy file-based saving
return r.saveCredentials(resp)
}
// deleteCredentialsFromStore removes credentials from the configured storage backend
// Falls back to legacy file-based deletion if no store is configured
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
// Use store if available
if r.store != nil {
return r.store.Delete(ctx, r.providerURL)
}
// Fallback to legacy file-based deletion
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// saveCredentials persists client credentials to a file (legacy method)
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
data, err := json.MarshalIndent(resp, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Write with restrictive permissions (owner read/write only)
if err := os.WriteFile(filePath, data, 0600); err != nil {
return fmt.Errorf("failed to write credentials file: %w", err)
}
r.logger.Debugf("Saved client credentials to %s", filePath)
return nil
}
// loadCredentials loads client credentials from a file (legacy method)
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No credentials file exists
}
return nil, fmt.Errorf("failed to read credentials file: %w", err)
}
var resp ClientRegistrationResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
}
return &resp, nil
}
// UpdateClientRegistration updates an existing client registration using RFC 7592
// This requires the registration_client_uri and registration_access_token from the original registration
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to update")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Build update request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build update request: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create update request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("update request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read update response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse update response: %w", err)
}
// Update cache
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
return &regResp, nil
}
// ReadClientRegistration reads the current client registration using RFC 7592
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to read")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create read request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("read request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse read response: %w", err)
}
return &regResp, nil
}
// DeleteClientRegistration deletes the client registration using RFC 7592
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return fmt.Errorf("no existing registration to delete")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
if err != nil {
return fmt.Errorf("failed to create delete request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return fmt.Errorf("delete request failed: %w", err)
}
defer resp.Body.Close()
// Handle error responses (204 No Content is success)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
}
// Clear cache
r.mu.Lock()
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials from storage if persistence is enabled
if r.config.PersistCredentials {
if err := r.deleteCredentialsFromStore(ctx); err != nil {
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
}
}
r.logger.Info("Successfully deleted client registration")
return nil
}
File diff suppressed because it is too large Load Diff
+620
View File
@@ -0,0 +1,620 @@
package traefikoidc
import (
"context"
"encoding/base64"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
"github.com/stretchr/testify/suite"
"golang.org/x/time/rate"
)
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
type ClockSkewEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
// Create JWK for the test key
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error") // Reduce noise
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
token, err := s.fixture.TokenWithSkew(0)
s.Require().NoError(err)
// Token at exact expiry - behavior is implementation-defined
err = s.tOidc.VerifyToken(token)
s.T().Logf("Exact expiry result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
token, err := s.fixture.TokenWithSkew(1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token should be valid 1 second before expiry")
}
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
}
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
// Most implementations allow 5-minute clock skew
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
// May pass or fail depending on implementation
s.T().Logf("4-minute expired token result: %v", err)
}
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.Error(err, "Token should be invalid 10 minutes after expiry")
}
func TestClockSkewEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ClockSkewEdgeCasesSuite))
}
// UnicodeClaimsSuite tests Unicode handling in JWT claims
type UnicodeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *UnicodeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *UnicodeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
token, err := s.fixture.TokenWithEmail("用户@example.com")
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode email should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestUnicodeName() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "田中太郎",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Unicode name should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test User 😀",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Emoji in claims should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestRTLText() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "مستخدم اختبار",
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "RTL text should be handled correctly")
}
func (s *UnicodeClaimsSuite) TestMixedScripts() {
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"name": "Test 测试 テスト",
"roles": []string{"admin", "管理者", "管理员"},
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Mixed scripts should be handled correctly")
}
func TestUnicodeClaimsSuite(t *testing.T) {
suite.Run(t, new(UnicodeClaimsSuite))
}
// LargeClaimsSuite tests large claim values
type LargeClaimsSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *LargeClaimsSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *LargeClaimsSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *LargeClaimsSuite) TestManyRoles() {
roles := make([]string, 100)
for i := 0; i < 100; i++ {
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithRoles(roles)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 100 roles should be handled")
}
func (s *LargeClaimsSuite) TestManyGroups() {
groups := make([]string, 50)
for i := 0; i < 50; i++ {
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
}
token, err := s.fixture.TokenWithGroups(groups)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with 50 groups should be handled")
}
func (s *LargeClaimsSuite) TestLongEmail() {
// RFC 5321 allows up to 254 characters
localPart := strings.Repeat("a", 64)
domain := strings.Repeat("b", 63) + ".com"
email := localPart + "@" + domain
token, err := s.fixture.TokenWithEmail(email)
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long email should be handled")
}
func (s *LargeClaimsSuite) TestLongSubject() {
longSub := strings.Repeat("subject", 100)
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"sub": longSub,
})
s.Require().NoError(err)
err = s.tOidc.VerifyToken(token)
s.NoError(err, "Token with long subject should be handled")
}
func TestLargeClaimsSuite(t *testing.T) {
suite.Run(t, new(LargeClaimsSuite))
}
// URLPathEdgeCasesSuite tests URL handling edge cases
type URLPathEdgeCasesSuite struct {
suite.Suite
}
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
longPath := "/" + strings.Repeat("segment/", 100)
req := httptest.NewRequest("GET", longPath, nil)
s.NotNil(req)
s.Contains(req.URL.Path, "segment")
}
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
paths := []string{
"/path%20with%20spaces",
"/path/with/日本語",
"/path?query=value&another=test",
"/path#fragment",
"/path/../traversal",
"/path/./current",
}
for _, path := range paths {
s.Run(path, func() {
req := httptest.NewRequest("GET", path, nil)
s.NotNil(req)
})
}
}
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
req := httptest.NewRequest("GET", "/", nil)
s.Equal("/", req.URL.Path)
}
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
req := httptest.NewRequest("GET", "//double//slashes//", nil)
s.NotNil(req)
}
func TestURLPathEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(URLPathEdgeCasesSuite))
}
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
type ConcurrencyEdgeCasesSuite struct {
suite.Suite
fixture *testutil.TokenFixture
tOidc *TraefikOidc
}
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
var err error
s.fixture, err = testutil.NewTokenFixture()
s.Require().NoError(err)
}
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
jwk := JWK{
Kty: "RSA",
Kid: s.fixture.KeyID,
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
}
jwkCache := &MockJWKCache{
JWKS: &JWKSet{Keys: []JWK{jwk}},
Err: nil,
}
tokenBlacklist := NewCache()
tokenCacheInternal := NewCache()
tokenCache := &TokenCache{}
if tokenCache.cache == nil {
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
tokenCache.cache = wrapper.cache
}
}
logger := NewLogger("error")
s.tOidc = &TraefikOidc{
issuerURL: s.fixture.Issuer,
clientID: s.fixture.Audience,
audience: s.fixture.Audience,
clientSecret: "test-client-secret",
roleClaimName: "roles",
groupClaimName: "groups",
userIdentifierClaim: "email",
jwkCache: jwkCache,
jwksURL: "https://test-jwks-url.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
goroutineWG: &sync.WaitGroup{},
ctx: context.Background(),
}
close(s.tOidc.initComplete)
s.tOidc.tokenVerifier = s.tOidc
s.tOidc.jwtVerifier = s.tOidc
s.T().Cleanup(func() {
if s.tOidc.tokenBlacklist != nil {
s.tOidc.tokenBlacklist.Close()
}
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
s.tOidc.tokenCache.cache.Close()
}
})
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
token, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
const goroutines = 50
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}()
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
const goroutines = 20
var wg sync.WaitGroup
errors := make(chan error, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
"custom": idx,
})
if err != nil {
errors <- err
return
}
if err := s.tOidc.VerifyToken(token); err != nil {
errors <- err
}
}(i)
}
wg.Wait()
close(errors)
var errCount int
for err := range errors {
s.T().Logf("Concurrent different token error: %v", err)
errCount++
}
s.Equal(0, errCount, "All concurrent different token validations should succeed")
}
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
validToken, err := s.fixture.ValidToken(nil)
s.Require().NoError(err)
expiredToken, err := s.fixture.ExpiredToken()
s.Require().NoError(err)
const goroutines = 40
var wg sync.WaitGroup
validCount := int32(0)
expiredCount := int32(0)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
var token string
if idx%2 == 0 {
token = validToken
} else {
token = expiredToken
}
err := s.tOidc.VerifyToken(token)
if idx%2 == 0 {
if err == nil {
atomic.AddInt32(&validCount, 1)
}
} else {
if err != nil {
atomic.AddInt32(&expiredCount, 1)
}
}
}(i)
}
wg.Wait()
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
}
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
}
+258
View File
@@ -0,0 +1,258 @@
package traefikoidc
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
type EnhancedMocksSuite struct {
suite.Suite
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Make some calls
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.NoError(err)
s.NotNil(result)
// Another call with different URL
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
// Verify calls were tracked
s.Equal(2, mock.GetJWKSCallCount())
mock.AssertGetJWKSCalled(s.T())
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
mock.AssertGetJWKSCallCount(s.T(), 2)
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
expectedErr := errors.New("network error")
mock := &EnhancedMockJWKCache{
Err: expectedErr,
}
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Nil(result)
s.Equal(expectedErr, err)
mock.AssertGetJWKSCalled(s.T())
}
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
s.Equal(1, mock.GetJWKSCallCount())
mock.Reset()
s.Equal(0, mock.GetJWKSCallCount())
s.Nil(mock.JWKS)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
mock := &EnhancedMockTokenVerifier{
Err: nil, // Valid tokens
}
// Verify a token
err := mock.VerifyToken("test-token-1")
s.NoError(err)
// Verify another token
err = mock.VerifyToken("test-token-2")
s.NoError(err)
// Check tracking
s.Equal(2, mock.GetVerifyTokenCallCount())
mock.AssertVerifyTokenCalled(s.T())
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
// Check last call
lastCall := mock.LastCall()
s.NotNil(lastCall)
s.Equal("test-token-2", lastCall.Token)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
callCount := 0
mock := &EnhancedMockTokenVerifier{
VerifyFunc: func(token string) error {
callCount++
if token == "invalid" {
return errors.New("invalid token")
}
return nil
},
}
// Valid token
err := mock.VerifyToken("valid-token")
s.NoError(err)
// Invalid token
err = mock.VerifyToken("invalid")
s.Error(err)
s.Equal(2, callCount)
s.Equal(2, mock.GetVerifyTokenCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
mock := &EnhancedMockTokenExchanger{
ExchangeResponse: &TokenResponse{
AccessToken: "access-token",
RefreshToken: "refresh-token",
ExpiresIn: 3600,
},
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
ExpiresIn: 3600,
},
}
// Exchange code
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
s.NoError(err)
s.Equal("access-token", resp.AccessToken)
// Refresh token
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
s.NoError(err)
s.Equal("new-access-token", resp.AccessToken)
// Revoke token
err = mock.RevokeTokenWithProvider("access-token", "access_token")
s.NoError(err)
// Check tracking
mock.AssertExchangeCalled(s.T())
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
mock.AssertRefreshCalled(s.T())
mock.AssertRevokeCalled(s.T())
s.Equal(1, mock.GetExchangeCallCount())
s.Equal(1, mock.GetRefreshCallCount())
s.Equal(1, mock.GetRevokeCallCount())
// Check last exchange call details
lastExchange := mock.LastExchangeCall()
s.NotNil(lastExchange)
s.Equal("authorization_code", lastExchange.GrantType)
s.Equal("auth-code", lastExchange.CodeOrToken)
s.Equal("https://redirect.com", lastExchange.RedirectURL)
}
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
mock := &EnhancedMockTokenExchanger{
ExchangeErr: errors.New("invalid_grant"),
RefreshErr: errors.New("refresh_expired"),
RevokeErr: errors.New("revoke_failed"),
}
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
s.Error(err)
s.Contains(err.Error(), "invalid_grant")
_, err = mock.GetNewTokenWithRefreshToken("token")
s.Error(err)
s.Contains(err.Error(), "refresh_expired")
err = mock.RevokeTokenWithProvider("token", "access_token")
s.Error(err)
s.Contains(err.Error(), "revoke_failed")
}
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
mock := NewEnhancedMockCache()
// Set some values
mock.Set("key1", "value1", 5*time.Minute)
mock.Set("key2", "value2", 10*time.Minute)
// Get values
val, found := mock.Get("key1")
s.True(found)
s.Equal("value1", val)
_, found = mock.Get("nonexistent")
s.False(found)
// Delete
mock.Delete("key1")
// Verify tracking
mock.AssertSetCalled(s.T(), "key1")
mock.AssertSetCalled(s.T(), "key2")
mock.AssertGetCalled(s.T(), "key1")
mock.AssertGetCalled(s.T(), "nonexistent")
mock.AssertDeleteCalled(s.T(), "key1")
s.Equal(2, mock.SetCallCount())
s.Equal(2, mock.GetCallCount())
}
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
mock := NewEnhancedMockCache()
// The enhanced mock actually stores data
mock.Set("key", "value", time.Hour)
s.Equal(1, mock.Size())
val, found := mock.Get("key")
s.True(found)
s.Equal("value", val)
mock.Delete("key")
s.Equal(0, mock.Size())
_, found = mock.Get("key")
s.False(found)
}
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
mock := NewEnhancedMockCache()
mock.Set("key1", "value1", time.Hour)
mock.Set("key2", "value2", time.Hour)
s.Equal(2, mock.Size())
mock.Clear()
s.Equal(0, mock.Size())
}
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
mock := &EnhancedMockJWKCache{
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
}
// Concurrent calls should be safe
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
s.Equal(10, mock.GetJWKSCallCount())
}
func TestEnhancedMocksSuite(t *testing.T) {
suite.Run(t, new(EnhancedMocksSuite))
}
+604
View File
@@ -0,0 +1,604 @@
package traefikoidc
import (
"context"
"crypto"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/stretchr/testify/assert"
)
// EnhancedMockJWKCache is an improved state-based mock with call tracking
type EnhancedMockJWKCache struct {
Err error
JWKS *JWKSet
GetJWKSCalls []JWKSCall
mu sync.RWMutex
getJWKSCallsMu sync.Mutex
CleanupCalls int32
CloseCalls int32
}
// JWKSCall records parameters from a GetJWKS call
type JWKSCall struct {
Timestamp time.Time
URL string
}
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
URL: jwksURL,
Timestamp: time.Now(),
})
m.getJWKSCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
return m.JWKS, m.Err
}
func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
if jwks == nil {
return nil, fmt.Errorf("JWKS is nil")
}
for i := range jwks.Keys {
k := &jwks.Keys[i]
if k.Kid != kid {
continue
}
switch k.Kty {
case "RSA":
return k.ToRSAPublicKey()
case "EC":
return k.ToECDSAPublicKey()
default:
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
}
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
func (m *EnhancedMockJWKCache) Cleanup() {
atomic.AddInt32(&m.CleanupCalls, 1)
m.mu.Lock()
defer m.mu.Unlock()
m.JWKS = nil
m.Err = nil
}
func (m *EnhancedMockJWKCache) Close() {
atomic.AddInt32(&m.CloseCalls, 1)
}
// Assertion helpers
// AssertGetJWKSCalled verifies GetJWKS was called
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
}
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
for _, call := range m.GetJWKSCalls {
if call.URL == expectedURL {
return true
}
}
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
}
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
}
// GetJWKSCallCount returns the number of GetJWKS calls
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
m.getJWKSCallsMu.Lock()
defer m.getJWKSCallsMu.Unlock()
return len(m.GetJWKSCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockJWKCache) Reset() {
m.mu.Lock()
m.JWKS = nil
m.Err = nil
m.mu.Unlock()
m.getJWKSCallsMu.Lock()
m.GetJWKSCalls = nil
m.getJWKSCallsMu.Unlock()
atomic.StoreInt32(&m.CleanupCalls, 0)
atomic.StoreInt32(&m.CloseCalls, 0)
}
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
type EnhancedMockTokenVerifier struct {
Err error
VerifyFunc func(token string) error
VerifyCalls []TokenVerifyCall
mu sync.RWMutex
verifyCallsMu sync.Mutex
}
// TokenVerifyCall records parameters from a VerifyToken call
type TokenVerifyCall struct {
Timestamp time.Time
Result error
Token string
}
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
var result error
m.mu.RLock()
if m.VerifyFunc != nil {
result = m.VerifyFunc(token)
} else {
result = m.Err
}
m.mu.RUnlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
Token: token,
Timestamp: time.Now(),
Result: result,
})
m.verifyCallsMu.Unlock()
return result
}
// Assertion helpers
// AssertVerifyTokenCalled verifies VerifyToken was called
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
}
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
for _, call := range m.VerifyCalls {
if call.Token == expectedToken {
return true
}
}
return assert.Fail(t, "VerifyToken was not called with expected token")
}
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
}
// GetVerifyTokenCallCount returns the number of VerifyToken calls
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
return len(m.VerifyCalls)
}
// LastCall returns the most recent VerifyToken call
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
m.verifyCallsMu.Lock()
defer m.verifyCallsMu.Unlock()
if len(m.VerifyCalls) == 0 {
return nil
}
return &m.VerifyCalls[len(m.VerifyCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenVerifier) Reset() {
m.mu.Lock()
m.Err = nil
m.VerifyFunc = nil
m.mu.Unlock()
m.verifyCallsMu.Lock()
m.VerifyCalls = nil
m.verifyCallsMu.Unlock()
}
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
type EnhancedMockTokenExchanger struct {
RefreshErr error
RevokeErr error
ExchangeErr error
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshResponse *TokenResponse
ExchangeResponse *TokenResponse
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
ExchangeCalls []ExchangeCall
RefreshCalls []RefreshCall
RevokeCalls []RevokeCall
mu sync.RWMutex
exchangeCallsMu sync.Mutex
refreshCallsMu sync.Mutex
revokeCallsMu sync.Mutex
}
// ExchangeCall records parameters from an ExchangeCodeForToken call
type ExchangeCall struct {
Timestamp time.Time
GrantType string
CodeOrToken string
RedirectURL string
CodeVerifier string
}
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
type RefreshCall struct {
Timestamp time.Time
RefreshToken string
}
// RevokeCall records parameters from a RevokeTokenWithProvider call
type RevokeCall struct {
Timestamp time.Time
Token string
TokenType string
}
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
m.exchangeCallsMu.Lock()
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
GrantType: grantType,
CodeOrToken: codeOrToken,
RedirectURL: redirectURL,
CodeVerifier: codeVerifier,
Timestamp: time.Now(),
})
m.exchangeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.ExchangeCodeFunc != nil {
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
return m.ExchangeResponse, m.ExchangeErr
}
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
m.refreshCallsMu.Lock()
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
RefreshToken: refreshToken,
Timestamp: time.Now(),
})
m.refreshCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RefreshTokenFunc != nil {
return m.RefreshTokenFunc(refreshToken)
}
return m.RefreshResponse, m.RefreshErr
}
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
m.revokeCallsMu.Lock()
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
Token: token,
TokenType: tokenType,
Timestamp: time.Now(),
})
m.revokeCallsMu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
if m.RevokeTokenFunc != nil {
return m.RevokeTokenFunc(token, tokenType)
}
return m.RevokeErr
}
// Assertion helpers
// AssertExchangeCalled verifies ExchangeCodeForToken was called
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
}
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
for _, call := range m.ExchangeCalls {
if call.GrantType == grantType {
return true
}
}
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
}
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
}
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
}
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
return len(m.ExchangeCalls)
}
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
m.refreshCallsMu.Lock()
defer m.refreshCallsMu.Unlock()
return len(m.RefreshCalls)
}
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
m.revokeCallsMu.Lock()
defer m.revokeCallsMu.Unlock()
return len(m.RevokeCalls)
}
// LastExchangeCall returns the most recent ExchangeCodeForToken call
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
m.exchangeCallsMu.Lock()
defer m.exchangeCallsMu.Unlock()
if len(m.ExchangeCalls) == 0 {
return nil
}
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
}
// Reset clears all state and call tracking
func (m *EnhancedMockTokenExchanger) Reset() {
m.mu.Lock()
m.ExchangeResponse = nil
m.ExchangeErr = nil
m.RefreshResponse = nil
m.RefreshErr = nil
m.RevokeErr = nil
m.ExchangeCodeFunc = nil
m.RefreshTokenFunc = nil
m.RevokeTokenFunc = nil
m.mu.Unlock()
m.exchangeCallsMu.Lock()
m.ExchangeCalls = nil
m.exchangeCallsMu.Unlock()
m.refreshCallsMu.Lock()
m.RefreshCalls = nil
m.refreshCallsMu.Unlock()
m.revokeCallsMu.Lock()
m.RevokeCalls = nil
m.revokeCallsMu.Unlock()
}
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
type EnhancedMockCacheInterface struct {
data map[string]cacheEntry
GetCalls []CacheGetCall
SetCalls []CacheSetCall
DeleteCalls []string
maxSize int
mu sync.RWMutex
getCalls sync.Mutex
setCalls sync.Mutex
deleteCalls sync.Mutex
}
type cacheEntry struct {
value any
ttl time.Duration
}
// CacheGetCall records parameters from a Get call
type CacheGetCall struct {
Timestamp time.Time
Key string
Found bool
}
// CacheSetCall records parameters from a Set call
type CacheSetCall struct {
Timestamp time.Time
Value any
Key string
TTL time.Duration
}
// NewEnhancedMockCache creates a new enhanced cache mock
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
return &EnhancedMockCacheInterface{
data: make(map[string]cacheEntry),
maxSize: 1000,
}
}
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
m.setCalls.Lock()
m.SetCalls = append(m.SetCalls, CacheSetCall{
Key: key,
Value: value,
TTL: ttl,
Timestamp: time.Now(),
})
m.setCalls.Unlock()
m.mu.Lock()
m.data[key] = cacheEntry{value: value, ttl: ttl}
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
m.mu.RLock()
entry, found := m.data[key]
m.mu.RUnlock()
m.getCalls.Lock()
m.GetCalls = append(m.GetCalls, CacheGetCall{
Key: key,
Found: found,
Timestamp: time.Now(),
})
m.getCalls.Unlock()
if found {
return entry.value, true
}
return nil, false
}
func (m *EnhancedMockCacheInterface) Delete(key string) {
m.deleteCalls.Lock()
m.DeleteCalls = append(m.DeleteCalls, key)
m.deleteCalls.Unlock()
m.mu.Lock()
delete(m.data, key)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
m.mu.Lock()
m.maxSize = size
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Size() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.data)
}
func (m *EnhancedMockCacheInterface) Clear() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
}
func (m *EnhancedMockCacheInterface) Cleanup() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) Close() {
// No-op for mock
}
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
m.mu.RLock()
defer m.mu.RUnlock()
return map[string]any{
"size": len(m.data),
"max_size": m.maxSize,
}
}
// Assertion helpers
// AssertGetCalled verifies Get was called with specific key
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
m.getCalls.Lock()
defer m.getCalls.Unlock()
for _, call := range m.GetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Get was not called with key: "+key)
}
// AssertSetCalled verifies Set was called with specific key
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
m.setCalls.Lock()
defer m.setCalls.Unlock()
for _, call := range m.SetCalls {
if call.Key == key {
return true
}
}
return assert.Fail(t, "Set was not called with key: "+key)
}
// AssertDeleteCalled verifies Delete was called with specific key
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
m.deleteCalls.Lock()
defer m.deleteCalls.Unlock()
for _, k := range m.DeleteCalls {
if k == key {
return true
}
}
return assert.Fail(t, "Delete was not called with key: "+key)
}
// GetCallCount returns the number of Get calls
func (m *EnhancedMockCacheInterface) GetCallCount() int {
m.getCalls.Lock()
defer m.getCalls.Unlock()
return len(m.GetCalls)
}
// SetCallCount returns the number of Set calls
func (m *EnhancedMockCacheInterface) SetCallCount() int {
m.setCalls.Lock()
defer m.setCalls.Unlock()
return len(m.SetCalls)
}
// Reset clears all state and call tracking
func (m *EnhancedMockCacheInterface) Reset() {
m.mu.Lock()
m.data = make(map[string]cacheEntry)
m.mu.Unlock()
m.getCalls.Lock()
m.GetCalls = nil
m.getCalls.Unlock()
m.setCalls.Lock()
m.SetCalls = nil
m.setCalls.Unlock()
m.deleteCalls.Lock()
m.DeleteCalls = nil
m.deleteCalls.Unlock()
}
+152 -40
View File
@@ -2,10 +2,14 @@ package traefikoidc
import (
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"math"
"math/rand/v2"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -123,8 +127,10 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
}
if metrics["total_requests"].(int64) > 0 {
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
totalReq, _ := metrics["total_requests"].(int64) // Safe to ignore: type assertion with fallback
totalSucc, _ := metrics["total_successes"].(int64) // Safe to ignore: type assertion with fallback
if totalReq > 0 {
successRate := float64(totalSucc) / float64(totalReq)
metrics["success_rate"] = successRate
} else {
metrics["success_rate"] = 1.0
@@ -409,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
}
}
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
// fetching during startup. Uses more aggressive retry settings to handle the race
// condition where Traefik initializes the plugin before routes are fully established,
// or before TLS certificates are properly loaded.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func MetadataFetchRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 10, // More attempts for startup scenarios
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
MaxDelay: 10 * time.Second, // Cap at 10 seconds
BackoffFactor: 1.5, // Gentler backoff for startup
EnableJitter: true, // Prevent thundering herd
RetryableErrors: []string{
"connection refused",
"timeout",
"temporary failure",
"network unreachable",
"EOF",
"certificate",
"x509",
"tls",
},
}
}
// RetryExecutor implements retry logic with exponential backoff and jitter.
// It automatically retries failed operations based on configurable error patterns
// and uses exponential backoff to avoid overwhelming failing services.
@@ -485,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// isRetryableError checks if an error should trigger a retry
// isRetryableError determines if an error should trigger a retry attempt.
// Checks error message against configured retryable error patterns.
// Also handles startup-specific errors like Traefik default certificate errors
// and EOF errors that occur during service initialization.
func (re *RetryExecutor) isRetryableError(err error) bool {
if err == nil {
return false
}
// Check for Traefik default certificate error (startup race condition)
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
if isTraefikDefaultCertError(err) {
return true
}
// Check for EOF errors (common during startup when services aren't ready)
if isEOFError(err) {
return true
}
// Check for certificate errors (transient during startup)
if isCertificateError(err) {
return true
}
errStr := err.Error()
for _, retryableErr := range re.config.RetryableErrors {
@@ -536,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
delay = float64(re.config.MaxDelay)
}
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.EnableJitter {
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
delay += jitter
@@ -592,14 +642,10 @@ func (e *HTTPError) Error() string {
// OIDCError represents OIDC-specific errors with context information.
// It provides structured error reporting for authentication and authorization failures.
type OIDCError struct {
// Code identifies the specific error type
Code string
// Message provides a human-readable description
Message string
// Context contains additional error context (e.g., provider, session details)
Cause error
Context map[string]interface{}
// Cause is the underlying error that caused this error
Cause error
Code string
Message string
}
// Error returns the string representation of the OIDC error.
@@ -619,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
// SessionError represents session-related errors with context.
// Used for session management, validation, and storage errors.
type SessionError struct {
// Operation describes what session operation failed
Cause error
Operation string
// Message provides a human-readable description
Message string
// SessionID identifies the session (if available)
Message string
SessionID string
// Cause is the underlying error that caused this error
Cause error
}
// Error returns the string representation of the session error.
@@ -646,14 +688,10 @@ func (e *SessionError) Unwrap() error {
// TokenError represents token-related errors with validation context.
// Used for JWT validation, token refresh, and token format errors.
type TokenError struct {
// TokenType identifies the type of token (id_token, access_token, refresh_token)
Cause error
TokenType string
// Reason describes why the token is invalid
Reason string
// Message provides a human-readable description
Message string
// Cause is the underlying error that caused this error
Cause error
Reason string
Message string
}
// Error returns the string representation of the token error.
@@ -715,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
// It provides fallback mechanisms when primary services are unavailable and monitors
// service health to automatically recover when services become available again.
type GracefulDegradation struct {
// BaseRecoveryMechanism provides common functionality
*BaseRecoveryMechanism
// fallbacks stores service-specific fallback implementations
fallbacks map[string]func() (interface{}, error)
// healthChecks stores service health check functions
healthChecks map[string]func() bool
// degradedServices tracks which services are currently degraded
fallbacks map[string]func() (interface{}, error)
healthChecks map[string]func() bool
degradedServices map[string]time.Time
// config contains graceful degradation configuration
config GracefulDegradationConfig
// mutex protects shared state
mutex sync.RWMutex
// healthCheckTask manages background health checking
healthCheckTask *BackgroundTask
// stopChan signals shutdown
stopChan chan struct{}
// shutdownOnce ensures shutdown happens only once
shutdownOnce sync.Once
healthCheckTask *BackgroundTask
stopChan chan struct{}
config GracefulDegradationConfig
mutex sync.RWMutex
shutdownOnce sync.Once
}
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
@@ -925,7 +954,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
var degraded []string
degraded := make([]string, 0, len(gd.degradedServices))
for serviceName := range gd.degradedServices {
degraded = append(degraded, serviceName)
}
@@ -963,7 +992,7 @@ func (gd *GracefulDegradation) Close() {
// Don't set to nil to avoid race conditions
}
gd.logger.Info("GracefulDegradation shut down successfully")
gd.logger.Debug("GracefulDegradation shut down successfully")
})
}
@@ -1085,3 +1114,86 @@ func containsSubstring(s, substr string) bool {
}
return false
}
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
// certificate during cold-start, before the real certificates are loaded.
// This manifests as an x509.HostnameError where one of the certificate's DNS names
// ends with "traefik.default" (the default Traefik certificate pattern).
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func isTraefikDefaultCertError(err error) bool {
if err == nil {
return false
}
var hostnameErr x509.HostnameError
if errors.As(err, &hostnameErr) {
if hostnameErr.Certificate != nil {
for _, name := range hostnameErr.Certificate.DNSNames {
if strings.HasSuffix(name, "traefik.default") {
return true
}
}
}
}
return false
}
// isEOFError checks if an error is an EOF error, which can occur during
// connection establishment when the remote end closes unexpectedly.
// This is common during service startup when endpoints aren't fully ready.
func isEOFError(err error) bool {
if err == nil {
return false
}
// Check for direct EOF
if errors.Is(err, io.EOF) {
return true
}
// Check for unexpected EOF
if errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// Check error message for EOF patterns (wrapped errors)
errStr := err.Error()
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
}
// isCertificateError checks if an error is related to TLS certificate validation.
// These errors are often transient during startup when services are still initializing.
func isCertificateError(err error) bool {
if err == nil {
return false
}
// Check for x509 certificate errors
var certInvalidErr x509.CertificateInvalidError
var hostnameErr x509.HostnameError
var unknownAuthErr x509.UnknownAuthorityError
if errors.As(err, &certInvalidErr) ||
errors.As(err, &hostnameErr) ||
errors.As(err, &unknownAuthErr) {
return true
}
// Check error message for certificate patterns
errStr := strings.ToLower(err.Error())
certPatterns := []string{
"certificate",
"x509",
"tls",
"ssl",
}
for _, pattern := range certPatterns {
if strings.Contains(errStr, pattern) {
return true
}
}
return false
}
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import "testing"
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
for i := 0; i < b.N; i++ {
DefaultCircuitBreakerConfig()
}
}
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.GetBaseMetrics()
}
}
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
logger := GetSingletonNoOpLogger()
base := NewBaseRecoveryMechanism("test-mechanism", logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
base.RecordRequest()
}
}
File diff suppressed because it is too large Load Diff
+496
View File
@@ -0,0 +1,496 @@
# ============================================================================
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
# ============================================================================
#
# This example shows a complete, production-ready configuration for using
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
#
# ============================================================================
# Part 1: Traefik Static Configuration (traefik.yml)
# ============================================================================
# This file configures Traefik itself and enables the plugin.
# Place this in /etc/traefik/traefik.yml or mount it in your container.
---
# Static Configuration
api:
dashboard: true
insecure: false # Set to true only for local development
entryPoints:
web:
address: ":80"
http:
redirections:
entryPoint:
to: websecure
scheme: https
websecure:
address: ":443"
http:
tls:
certResolver: letsencrypt
certificatesResolvers:
letsencrypt:
acme:
email: admin@example.com
storage: /letsencrypt/acme.json
httpChallenge:
entryPoint: web
providers:
file:
filename: /etc/traefik/dynamic.yml
watch: true
# Enable the TraefikOIDC plugin
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
log:
level: INFO
format: json
accessLog:
format: json
# ============================================================================
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
# ============================================================================
# This file defines your routes, services, and middleware.
# Place this in /etc/traefik/dynamic.yml
---
http:
# -------------------------------------------------------------------------
# Middleware Definitions
# -------------------------------------------------------------------------
middlewares:
# Example 1: Minimal Redis Configuration
# Perfect for getting started quickly
oidc-minimal:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-application-client-id"
clientSecret: "your-client-secret-from-provider"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
# Minimal Redis configuration
redis:
enabled: true
address: "redis:6379"
# Example 2: Production Redis Configuration
# Recommended for production deployments with multiple Traefik replicas
oidc-production:
plugin:
traefikoidc:
# OIDC Provider Configuration
clientID: "prod-client-id"
clientSecret: "prod-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
# ----------------------------------------------------------------
# Optional: switch to RFC 7523 private_key_jwt client auth
# (Entra ID, Okta, Auth0, Keycloak). Replaces clientSecret with a
# signed JWT assertion. See README for details and PEM formats.
# ----------------------------------------------------------------
# clientAuthMethod: "private_key_jwt"
# clientAssertionKeyPath: "/etc/traefik/oidc/client-key.pem"
# clientAssertionKeyID: "prod-key-2026"
# clientAssertionAlg: "RS256" # or PS256/384/512, ES256/384/512
# Session Configuration
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
sessionMaxAge: 28800 # 8 hours
# Security Settings
forceHTTPS: true
strictAudienceValidation: true
# Redis Configuration for Multi-Replica Deployment
redis:
enabled: true
address: "redis-master.redis-namespace.svc.cluster.local:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
db: 0
keyPrefix: "traefikoidc:prod:"
# Cache Strategy
cacheMode: "hybrid" # Fast local cache + shared Redis
# Connection Pooling
poolSize: 20
connectTimeout: 5
readTimeout: 3
writeTimeout: 3
# Resilience Features
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS (for production security)
oidc-secure:
plugin:
traefikoidc:
clientID: "secure-client-id"
clientSecret: "secure-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
redis:
enabled: true
address: "redis.example.com:6380"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
enableTLS: true
tlsSkipVerify: false # Verify certificates in production
cacheMode: "redis"
# Example 4: Hybrid Mode (Best Performance + Consistency)
# Local cache for hot data, Redis for consistency across replicas
oidc-hybrid:
plugin:
traefikoidc:
clientID: "app-client-id"
clientSecret: "app-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
redis:
enabled: true
address: "redis:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
cacheMode: "hybrid"
# Hybrid mode L1 cache settings
hybridL1Size: 1000 # Number of items in local cache
hybridL1MemoryMB: 20 # MB of memory for local cache
# -------------------------------------------------------------------------
# Router Definitions
# -------------------------------------------------------------------------
routers:
# Protected application using OIDC authentication
my-app:
rule: "Host(`app.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-production # Use the OIDC middleware
service: my-app-service
tls:
certResolver: letsencrypt
# Another app with minimal OIDC config
simple-app:
rule: "Host(`simple.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-minimal
service: simple-app-service
tls:
certResolver: letsencrypt
# -------------------------------------------------------------------------
# Service Definitions
# -------------------------------------------------------------------------
services:
my-app-service:
loadBalancer:
servers:
- url: "http://my-app:8080"
healthCheck:
path: /health
interval: 30s
timeout: 5s
simple-app-service:
loadBalancer:
servers:
- url: "http://simple-app:3000"
# ============================================================================
# Part 3: Docker Compose Example
# ============================================================================
---
# docker-compose.yml
version: '3.8'
services:
# Redis service for shared caching
redis:
image: redis:7-alpine
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
ports:
- "6379:6379"
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 10s
timeout: 3s
retries: 5
networks:
- traefik-network
# Traefik with TraefikOIDC plugin
traefik:
image: traefik:v3.2
command:
- "--api.dashboard=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--providers.file.filename=/etc/traefik/dynamic.yml"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.8.0"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
- ./letsencrypt:/letsencrypt
depends_on:
- redis
networks:
- traefik-network
# Your application
my-app:
image: my-app:latest
labels:
- "traefik.enable=true"
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
- "traefik.http.routers.my-app.entrypoints=websecure"
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
# OIDC Middleware Configuration with Redis (using labels)
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
# Redis configuration
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
networks:
- traefik-network
deploy:
replicas: 3 # Multiple replicas sharing Redis cache
volumes:
redis-data:
networks:
traefik-network:
driver: bridge
# ============================================================================
# Part 4: Kubernetes Example
# ============================================================================
---
# kubernetes-example.yaml
# Redis Deployment
apiVersion: apps/v1
kind: Deployment
metadata:
name: redis
namespace: traefik
spec:
replicas: 1
selector:
matchLabels:
app: redis
template:
metadata:
labels:
app: redis
spec:
containers:
- name: redis
image: redis:7-alpine
args:
- redis-server
- --requirepass
- $(REDIS_PASSWORD)
- --maxmemory
- 512mb
- --maxmemory-policy
- allkeys-lru
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: redis-secret
key: password
ports:
- containerPort: 6379
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
---
# Redis Service
apiVersion: v1
kind: Service
metadata:
name: redis
namespace: traefik
spec:
selector:
app: redis
ports:
- port: 6379
targetPort: 6379
---
# Redis Secret
apiVersion: v1
kind: Secret
metadata:
name: redis-secret
namespace: traefik
type: Opaque
stringData:
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
---
# OIDC Middleware with Redis
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
namespace: traefik
spec:
plugin:
traefikoidc:
# OIDC Configuration
clientID: "kubernetes-client-id"
clientSecret: "kubernetes-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
# Redis Configuration
redis:
enabled: true
address: "redis.traefik.svc.cluster.local:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
db: 0
keyPrefix: "traefikoidc:k8s:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
enableHealthCheck: true
---
# IngressRoute using the middleware
apiVersion: traefik.io/v1alpha1
kind: IngressRoute
metadata:
name: my-app
namespace: default
spec:
entryPoints:
- websecure
routes:
- match: Host(`app.example.com`)
kind: Rule
middlewares:
- name: oidc-auth
namespace: traefik
services:
- name: my-app
port: 80
tls:
certResolver: letsencrypt
# ============================================================================
# Part 5: Environment Variables (Optional Fallback)
# ============================================================================
# If you prefer environment variables as fallback (not recommended for production),
# you can set these. NOTE: Plugin configuration takes precedence!
# Docker Compose env file (.env)
---
# OIDC Configuration
OIDC_CLIENT_ID=your-client-id
OIDC_CLIENT_SECRET=your-client-secret
OIDC_PROVIDER_URL=https://auth.example.com
# Redis Configuration (fallback)
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=yourredispassword
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=20
REDIS_ENABLE_CIRCUIT_BREAKER=true
REDIS_ENABLE_HEALTH_CHECK=true
# ============================================================================
# Configuration Cheat Sheet
# ============================================================================
# Minimal Setup (Quick Start):
# redis:
# enabled: true
# address: "redis:6379"
# Production Setup (Recommended):
# redis:
# enabled: true
# address: "redis-master:6379"
# password: "strong-password"
# cacheMode: "hybrid"
# enableCircuitBreaker: true
# enableHealthCheck: true
# High Security Setup:
# redis:
# enabled: true
# address: "redis.example.com:6380"
# password: "strong-password"
# enableTLS: true
# tlsSkipVerify: false
# cacheMode: "redis"
# Cache Modes:
# - "memory": Local cache only (default, no Redis needed)
# - "redis": Redis only (consistent, shared across replicas)
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
+15
View File
@@ -0,0 +1,15 @@
# Minimal oidcgate config. See docs/OIDCGATE.md for full reference.
listen: ":8080"
authPath: "/oauth2/auth"
startPath: "/oauth2/start"
providerURL: "https://accounts.google.com"
clientID: "REPLACE_ME.apps.googleusercontent.com"
clientSecret: "REPLACE_ME" # OR set OIDCGATE_CLIENT_SECRET
sessionEncryptionKey: "REPLACE_WITH_64_HEX_BYTES" # OR OIDCGATE_SESSION_ENCRYPTION_KEY
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
postLogoutRedirectURI: "/"
# allowedUserDomains: [company.com]
# excludedURLs: [/health, /metrics]
+149
View File
@@ -0,0 +1,149 @@
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
# This example shows how to configure Redis through Traefik's dynamic configuration
# Static configuration (traefik.yml)
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
# Dynamic configuration (dynamic.yml or labels)
http:
middlewares:
# Example 1: Basic Redis configuration
oidc-redis-basic:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis configuration
redis:
enabled: true
address: "redis:6379"
# password: "your-redis-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
# Example 2: Redis with resilience features
oidc-redis-resilient:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with full resilience configuration
redis:
enabled: true
address: "redis:6379"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder - use your actual password
db: 1
keyPrefix: "myapp:"
poolSize: 20
connectTimeout: 10
readTimeout: 5
writeTimeout: 5
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
# Circuit breaker settings
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
# Health check settings
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS
oidc-redis-tls:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with TLS configuration
redis:
enabled: true
address: "redis.example.com:6380"
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder
enableTLS: true
tlsSkipVerify: false # Set to true only for testing
cacheMode: "redis"
routers:
my-app:
rule: "Host(`app.example.com`)"
middlewares:
- oidc-redis-basic
service: my-app-service
services:
my-app-service:
loadBalancer:
servers:
- url: "http://localhost:8080"
# Docker Compose labels example
# version: '3.8'
# services:
# traefik:
# image: traefik:v3.0
# # ... other config ...
#
# my-app:
# image: my-app:latest
# labels:
# - "traefik.enable=true"
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
# - "traefik.http.routers.my-app.middlewares=my-oidc"
# # OIDC middleware configuration with Redis
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
# # Redis configuration via labels
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
#
# redis:
# image: redis:7-alpine
# command: redis-server --requirepass redis-password
# # ... other config ...
# Environment variable fallback (optional)
# If Redis configuration is not provided in Traefik config, these environment variables
# can be used as a fallback (but Traefik config takes precedence):
#
# REDIS_ENABLED=true
# REDIS_ADDRESS=redis:6379
# REDIS_PASSWORD=secret
# REDIS_DB=0
# REDIS_KEY_PREFIX=traefikoidc:
# REDIS_CACHE_MODE=redis
# REDIS_POOL_SIZE=10
# REDIS_CONNECT_TIMEOUT=5
# REDIS_READ_TIMEOUT=3
# REDIS_WRITE_TIMEOUT=3
# REDIS_ENABLE_TLS=false
# REDIS_TLS_SKIP_VERIFY=false
# REDIS_ENABLE_CIRCUIT_BREAKER=true
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
# REDIS_ENABLE_HEALTH_CHECK=true
# REDIS_HEALTH_CHECK_INTERVAL=30
-797
View File
@@ -1,797 +0,0 @@
package features
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Mock types for testing
type TemplatedHeader struct {
Name string `json:"name"`
Value string `json:"value"`
}
type MockConfig struct {
ProviderURL string `json:"providerURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
CallbackURL string `json:"callbackURL"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
Headers []TemplatedHeader `json:"headers"`
}
// TestTemplateHeaderFeatures consolidates all template header-related tests
func TestTemplateHeaderFeatures(t *testing.T) {
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
t.Run("Template_Execution_Context", testTemplateExecutionContext)
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
t.Run("Missing_Field_Handling", testMissingFieldHandling)
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
}
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
// receive wrong data types during execution - reproduces GitHub issue #55
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
testCases := []struct {
name string
templateText string
templateData interface{}
errorContains string
expectError bool
}{
{
name: "correct map data",
templateText: "Bearer {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "valid-token",
},
expectError: false,
},
{
name: "boolean as root context - reproduces issue #55",
templateText: "Bearer {{.AccessToken}}",
templateData: true,
expectError: true,
errorContains: "can't evaluate field AccessToken in type bool",
},
{
name: "string as root context",
templateText: "Bearer {{.AccessToken}}",
templateData: "just a string",
expectError: true,
errorContains: "can't evaluate field AccessToken in type string",
},
{
name: "nested claims access with correct data",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": map[string]interface{}{
"email": "user@example.com",
},
},
expectError: false,
},
{
name: "nested claims with wrong structure",
templateText: "User: {{.Claims.email}}",
templateData: map[string]interface{}{
"Claims": "not a map",
},
expectError: true,
errorContains: "can't evaluate field email in type",
},
{
name: "complex nested structure",
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
templateData: map[string]interface{}{
"AccessToken": "token123",
"Claims": map[string]interface{}{
"sub": "user-id",
"groups": "admin,users",
},
},
expectError: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.templateData)
if tc.expectError {
require.Error(t, err)
if tc.errorContains != "" {
assert.Contains(t, err.Error(), tc.errorContains)
}
} else {
require.NoError(t, err)
}
})
}
}
// testTemplateParsingValidation ensures templates are parsed correctly
func testTemplateParsingValidation(t *testing.T) {
testCases := []struct {
name string
headerTemplates []TemplatedHeader
shouldError bool
}{
{
name: "valid bearer token template",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
shouldError: false,
},
{
name: "multiple valid templates",
headerTemplates: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
shouldError: false,
},
{
name: "template with conditional logic",
headerTemplates: []TemplatedHeader{
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
},
shouldError: false,
},
{
name: "invalid template syntax",
headerTemplates: []TemplatedHeader{
{Name: "Bad-Template", Value: "{{.AccessToken"},
},
shouldError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, header := range tc.headerTemplates {
_, err := template.New(header.Name).Parse(header.Value)
if tc.shouldError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
})
}
}
// testMiddlewareHeaderTemplating simulates the actual middleware flow
func testMiddlewareHeaderTemplating(t *testing.T) {
testCases := []struct {
name string
headers []TemplatedHeader
accessToken string
idToken string
claims map[string]interface{}
expectedValues map[string]string
}{
{
name: "authorization header with access token",
headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
expectedValues: map[string]string{
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
},
},
{
name: "multiple headers with claims",
headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
},
accessToken: "token123",
claims: map[string]interface{}{
"email": "user@example.com",
"groups": "admin,developers",
},
expectedValues: map[string]string{
"X-User-Email": "user@example.com",
"X-User-Groups": "admin,developers",
"X-Auth-Token": "token123",
},
},
{
name: "complex template expressions",
headers: []TemplatedHeader{
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
},
accessToken: "access-token",
idToken: "id-token",
claims: map[string]interface{}{
"sub": "user-12345",
"email": "john@example.com",
},
expectedValues: map[string]string{
"X-User-Info": "user-12345 (john@example.com)",
"X-Auth-Header": "Bearer access-token | ID: id-token",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Parse all templates
headerTemplates := make(map[string]*template.Template)
for _, header := range tc.headers {
tmpl, err := template.New(header.Name).Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = tmpl
}
// Create template data
templateData := map[string]interface{}{
"AccessToken": tc.accessToken,
"IDToken": tc.idToken,
"Claims": tc.claims,
}
// Create a test request
req := httptest.NewRequest("GET", "/test", nil)
// Execute templates and set headers
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
req.Header.Set(headerName, buf.String())
}
// Verify all expected headers are set correctly
for headerName, expectedValue := range tc.expectedValues {
actualValue := req.Header.Get(headerName)
assert.Equal(t, expectedValue, actualValue)
}
})
}
}
// testJSONConfigParsing tests that JSON configuration is properly parsed
func testJSONConfigParsing(t *testing.T) {
testCases := []struct {
name string
jsonConfig string
expectedError bool
description string
}{
{
name: "valid JSON configuration",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": "Bearer {{.AccessToken}}"
}
]
}`,
expectedError: false,
description: "Properly formatted JSON with string values",
},
{
name: "JSON with boolean value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": true
}
]
}`,
expectedError: true,
description: "Boolean value instead of string template",
},
{
name: "JSON with number value",
jsonConfig: `{
"headers": [
{
"name": "Authorization",
"value": 123
}
]
}`,
expectedError: true,
description: "Number value instead of string template",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var config struct {
Headers []TemplatedHeader `json:"headers"`
}
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
if tc.expectedError {
require.Error(t, err, tc.description)
} else {
require.NoError(t, err, tc.description)
}
})
}
}
// testTemplateDoubleProcessing tests if template strings are being double-processed
func testTemplateDoubleProcessing(t *testing.T) {
// Simulate how Traefik passes config to the plugin
config := &MockConfig{
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Verify that template strings are still raw (not processed)
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
// Simulate template parsing during initialization
headerTemplates := make(map[string]*template.Template)
funcMap := template.FuncMap{
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" || val == "<no value>" {
return defaultVal
}
return val
},
"get": func(m interface{}, key string) interface{} {
if mapVal, ok := m.(map[string]interface{}); ok {
if val, exists := mapVal[key]; exists {
return val
}
}
return ""
},
}
for _, header := range config.Headers {
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
parsedTmpl, err := tmpl.Parse(header.Value)
require.NoError(t, err)
headerTemplates[header.Name] = parsedTmpl
}
// Test execution with actual claims
claims := map[string]interface{}{
"email": "user@example.com",
// Note: internal_role is missing
}
templateData := map[string]interface{}{
"Claims": claims,
}
// Execute templates
for headerName, tmpl := range headerTemplates {
var buf bytes.Buffer
err := tmpl.Execute(&buf, templateData)
require.NoError(t, err)
result := buf.String()
if headerName == "X-User-Email" {
assert.Equal(t, "user@example.com", result)
} else if headerName == "X-User-Role" {
// With missingkey=zero, missing fields return "<no value>"
assert.Equal(t, "<no value>", result)
}
}
}
// testTemplateExecutionContext tests the specific template data context
func testTemplateExecutionContext(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expectedValue string
}{
{
name: "Access and ID token distinction",
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"IDToken": "id-token-value",
"Claims": map[string]interface{}{},
},
expectedValue: "Access: access-token-value ID: id-token-value",
},
{
name: "Combining tokens and claims",
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
data: map[string]interface{}{
"AccessToken": "access-token",
"IDToken": "id-token",
"Claims": map[string]interface{}{
"sub": "user123",
},
},
expectedValue: "User: user123 Token: access-token",
},
{
name: "Custom non-standard claims",
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
data: map[string]interface{}{
"AccessToken": "access-token-value",
"Claims": map[string]interface{}{
"role": "admin",
"permissions": "read:all,write:own",
},
},
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
},
{
name: "Deeply nested custom claims",
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"app_metadata": map[string]interface{}{
"organization": map[string]interface{}{
"name": "acme-corp",
},
"team": "platform",
},
},
},
expectedValue: "X-Organization: acme-corp, X-Team: platform",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expectedValue, buf.String())
})
}
}
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
func testTemplateIntegrationWithPlugin(t *testing.T) {
// Test template integration using mock plugin components
// Set up test OIDC server
var testServerURL string
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": testServerURL,
"authorization_endpoint": testServerURL + "/auth",
"token_endpoint": testServerURL + "/token",
"jwks_uri": testServerURL + "/jwks",
"userinfo_endpoint": testServerURL + "/userinfo",
})
case "/jwks":
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []interface{}{},
})
default:
http.NotFound(w, r)
}
}))
defer testServer.Close()
testServerURL = testServer.URL
// Create config with templates that reference potentially missing fields
config := &MockConfig{
ProviderURL: testServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-characters",
Headers: []TemplatedHeader{
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
},
}
// Initialize plugin would be done here
ctx := context.Background()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Test would create plugin handler here
_ = ctx
_ = next
_ = config
}
// testTemplateSyntaxValidation tests that template syntax is properly validated
func testTemplateSyntaxValidation(t *testing.T) {
validTemplates := []string{
"{{.Claims.email}}",
"{{.Claims.internal_role}}",
"{{.AccessToken}}",
"{{.IdToken}}",
"{{.RefreshToken}}",
}
for _, tmplStr := range validTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
}
// Test invalid templates
invalidTemplates := []struct {
template string
reason string
}{
{"{{call .SomeFunc}}", "function calls not allowed"},
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
{"{{index .Array 0}}", "index access blocked"},
{"{{printf \"%s\" .Data}}", "printf blocked"},
}
for _, tc := range invalidTemplates {
err := validateTemplateSecure(tc.template)
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
}
// Test safe custom functions
safeTemplates := []string{
"{{get .Claims \"internal_role\"}}",
"{{default \"guest\" .Claims.role}}",
}
for _, tmplStr := range safeTemplates {
err := validateTemplateSecure(tmplStr)
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
}
}
// Mock validation function for template security
func validateTemplateSecure(templateStr string) error {
// List of potentially dangerous template actions
dangerousFunctions := []string{
"call", "range", "with", "index", "printf", "println", "print",
"js", "html", "urlquery", "base64", "exec",
}
for _, dangerous := range dangerousFunctions {
if strings.Contains(templateStr, dangerous) {
return fmt.Errorf("dangerous template function detected: %s", dangerous)
}
}
// Define safe custom functions
funcMap := template.FuncMap{
"get": func(data map[string]interface{}, key string) interface{} {
return data[key]
},
"default": func(defaultVal interface{}, val interface{}) interface{} {
if val == nil || val == "" {
return defaultVal
}
return val
},
}
// Try to parse the template with custom functions to check for syntax errors
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
return err
}
// testMissingFieldHandling tests handling of missing fields in templates
func testMissingFieldHandling(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "missing claim field",
templateText: "{{.Claims.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{},
},
expected: "<no value>",
},
{
name: "missing nested field",
templateText: "{{.Claims.user.missing}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"user": map[string]interface{}{},
},
},
expected: "<no value>",
},
{
name: "missing entire path",
templateText: "{{.Missing.Path.Field}}",
data: map[string]interface{}{},
expected: "<no value>",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testComplexTemplateExpressions tests complex template expressions
func testComplexTemplateExpressions(t *testing.T) {
testCases := []struct {
name string
templateText string
data map[string]interface{}
expected string
}{
{
name: "conditional template",
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"admin": true,
},
},
expected: "Admin User",
},
{
name: "multiple claims concatenation",
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"firstName": "John",
"lastName": "Doe",
"email": "john.doe@example.com",
},
},
expected: "John Doe <john.doe@example.com>",
},
{
name: "array access",
templateText: "{{index .Claims.roles 0}}",
data: map[string]interface{}{
"Claims": map[string]interface{}{
"roles": []string{"admin", "user"},
},
},
expected: "admin",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.New("test").Parse(tc.templateText)
require.NoError(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, tc.data)
require.NoError(t, err)
assert.Equal(t, tc.expected, buf.String())
})
}
}
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
func testTraefikConfigurationParsing(t *testing.T) {
testCases := []struct {
name string
config *MockConfig
expectError bool
description string
}{
{
name: "valid configuration with templated headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
},
},
expectError: false,
description: "Standard configuration should work",
},
{
name: "configuration with multiple headers",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
},
},
expectError: false,
description: "Multiple headers should work",
},
{
name: "empty headers configuration",
config: &MockConfig{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
CallbackURL: "/oauth2/callback",
Headers: []TemplatedHeader{},
},
expectError: false,
description: "Empty headers should not cause issues",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a simple next handler
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Try to create the middleware would be done here
ctx := context.Background()
// Test would create middleware handler here
_ = ctx
_ = next
_ = tc.config
// For now, we just validate the configuration is well-formed
if !tc.expectError {
require.NotNil(t, tc.config, tc.description)
require.NotEmpty(t, tc.config.ClientID, tc.description)
}
})
}
}
+37
View File
@@ -0,0 +1,37 @@
package traefikoidc
import (
"net/http"
"net/url"
"strings"
)
// originalRequestURI returns the request URI that should be used as the
// post-login redirect target. When TrustForwardedURI is enabled and the
// X-Forwarded-Uri header carries a safe same-origin path, that header
// wins. Otherwise (or if the header is missing/unsafe), falls back to
// req.URL.RequestURI() — the path the request reached the proxy with.
//
// "Safe" means: starts with "/", does NOT start with "//" (protocol-relative
// URLs can change host), and has no scheme or host after parsing. This
// prevents an attacker-controllable header from triggering an open redirect
// via http.Redirect later in the auth flow.
func (t *TraefikOidc) originalRequestURI(req *http.Request) string {
if t.trustForwardedURI {
if v := req.Header.Get("X-Forwarded-Uri"); v != "" && isSafeRedirectTarget(v) {
return v
}
}
return req.URL.RequestURI()
}
func isSafeRedirectTarget(v string) bool {
if !strings.HasPrefix(v, "/") || strings.HasPrefix(v, "//") {
return false
}
u, err := url.Parse(v)
if err != nil {
return false
}
return u.Host == "" && u.Scheme == ""
}
+69
View File
@@ -0,0 +1,69 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestOriginalRequestURI_DefaultOff(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: false}
req := httptest.NewRequest(http.MethodGet, "/protected?x=1", nil)
req.Header.Set("X-Forwarded-Uri", "/spoofed")
if got := tr.originalRequestURI(req); got != "/protected?x=1" {
t.Fatalf("default-off: want /protected?x=1, got %q", got)
}
}
func TestOriginalRequestURI_TrustEnabled(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: true}
req := httptest.NewRequest(http.MethodGet, "/protected?x=1", nil)
req.Header.Set("X-Forwarded-Uri", "/real?y=2")
if got := tr.originalRequestURI(req); got != "/real?y=2" {
t.Fatalf("trust-on with header: want /real?y=2, got %q", got)
}
}
func TestOriginalRequestURI_TrustEnabledNoHeader(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: true}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
if got := tr.originalRequestURI(req); got != "/protected" {
t.Fatalf("trust-on no header: want /protected, got %q", got)
}
}
func TestOriginalRequestURI_RejectsAbsoluteURL(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: true}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("X-Forwarded-Uri", "https://evil.example/phish")
if got := tr.originalRequestURI(req); got != "/protected" {
t.Fatalf("absolute URL must be rejected, want /protected fallback, got %q", got)
}
}
func TestOriginalRequestURI_RejectsProtocolRelative(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: true}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("X-Forwarded-Uri", "//evil.example/phish")
if got := tr.originalRequestURI(req); got != "/protected" {
t.Fatalf("protocol-relative URL must be rejected, want /protected fallback, got %q", got)
}
}
func TestOriginalRequestURI_AcceptsSafePathWithQuery(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: true}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("X-Forwarded-Uri", "/safe?x=1&y=2")
if got := tr.originalRequestURI(req); got != "/safe?x=1&y=2" {
t.Fatalf("safe path with query must be accepted, got %q", got)
}
}
func TestOriginalRequestURI_RejectsBareHostnameNoSlash(t *testing.T) {
tr := &TraefikOidc{trustForwardedURI: true}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("X-Forwarded-Uri", "evil.example/phish")
if got := tr.originalRequestURI(req); got != "/protected" {
t.Fatalf("non-/ prefix must be rejected, got %q", got)
}
}
+8 -3
View File
@@ -3,15 +3,20 @@ module github.com/lukaszraczylo/traefikoidc
go 1.24.0
require (
github.com/google/uuid v1.6.0
github.com/alicebob/miniredis/v2 v2.35.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.13.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
)
+18 -4
View File
@@ -1,19 +1,33 @@
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+7 -7
View File
@@ -10,16 +10,16 @@ import (
type GoroutineManager struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
goroutines map[string]*managedGoroutine
logger *Logger
wg sync.WaitGroup
mu sync.RWMutex
}
type managedGoroutine struct {
name string
cancel context.CancelFunc
startTime time.Time
cancel context.CancelFunc
name string
running bool
}
@@ -86,7 +86,7 @@ func (m *GoroutineManager) StartPeriodicTask(name string, interval time.Duration
for {
select {
case <-ctx.Done():
m.logger.Debugf("Periodic task %s cancelled", name)
m.logger.Debugf("Periodic task %s canceled", name)
return
case <-ticker.C:
task()
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
// GoroutineStatus represents the status of a managed goroutine
type GoroutineStatus struct {
Name string
Running bool
StartTime time.Time
Name string
Runtime time.Duration
Running bool
}
// ErrShutdownTimeout is returned when shutdown times out
+625
View File
@@ -0,0 +1,625 @@
package traefikoidc
import (
"context"
"sync/atomic"
"testing"
"time"
)
// Test GoroutineManager Creation
func TestNewGoroutineManager(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
if gm == nil {
t.Fatal("Expected non-nil goroutine manager")
}
if gm.ctx == nil {
t.Error("Expected context to be initialized")
}
if gm.cancel == nil {
t.Error("Expected cancel function to be initialized")
}
if gm.goroutines == nil {
t.Error("Expected goroutines map to be initialized")
}
if gm.logger != logger {
t.Error("Expected logger to be set")
}
}
// Test Starting Goroutines
func TestStartGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("test-goroutine", func(ctx context.Context) {
executed.Store(true)
})
// Give goroutine time to execute
time.Sleep(50 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute")
}
status := gm.GetStatus()
if len(status) != 1 {
t.Errorf("Expected 1 goroutine in status, got %d", len(status))
}
if _, exists := status["test-goroutine"]; !exists {
t.Error("Expected goroutine 'test-goroutine' in status")
}
}
func TestStartGoroutineDuplicate(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
// Start a long-running goroutine
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
counter.Add(1)
<-ctx.Done()
})
// Give first goroutine time to start
time.Sleep(50 * time.Millisecond)
// Try to start another with same name (should be skipped)
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
counter.Add(1)
})
time.Sleep(50 * time.Millisecond)
// Should only have executed once
if counter.Load() != 1 {
t.Errorf("Expected counter to be 1 (duplicate should be skipped), got %d", counter.Load())
}
}
func TestStartGoroutineContextCancellation(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
started := atomic.Bool{}
canceled := atomic.Bool{}
gm.StartGoroutine("cancel-test", func(ctx context.Context) {
started.Store(true)
<-ctx.Done()
canceled.Store(true)
})
// Wait for goroutine to start
time.Sleep(50 * time.Millisecond)
if !started.Load() {
t.Error("Expected goroutine to start")
}
// Stop the goroutine
gm.StopGoroutine("cancel-test")
// Wait for cancellation
time.Sleep(50 * time.Millisecond)
if !canceled.Load() {
t.Error("Expected goroutine to be canceled")
}
}
func TestStartGoroutineWithPanic(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("panic-test", func(ctx context.Context) {
executed.Store(true)
panic("test panic")
})
// Give goroutine time to panic and recover
time.Sleep(100 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute before panic")
}
// Check that goroutine is marked as not running after panic
status := gm.GetStatus()
if goroutineStatus, exists := status["panic-test"]; exists {
if goroutineStatus.Running {
t.Error("Expected goroutine to be marked as not running after panic")
}
}
// Manager should still be functional
counter := atomic.Int32{}
gm.StartGoroutine("after-panic", func(ctx context.Context) {
counter.Add(1)
})
time.Sleep(50 * time.Millisecond)
if counter.Load() != 1 {
t.Error("Expected manager to still be functional after panic recovery")
}
}
// Test Periodic Tasks
func TestStartPeriodicTask(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("periodic-test", 50*time.Millisecond, func() {
counter.Add(1)
})
// Wait for multiple executions
time.Sleep(160 * time.Millisecond)
// Should have executed at least 2-3 times
count := counter.Load()
if count < 2 {
t.Errorf("Expected periodic task to execute at least 2 times, got %d", count)
}
}
func TestStartPeriodicTaskCancellation(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("cancel-periodic", 50*time.Millisecond, func() {
counter.Add(1)
})
// Wait for some executions
time.Sleep(120 * time.Millisecond)
// Stop the task
gm.StopGoroutine("cancel-periodic")
countBeforeStop := counter.Load()
// Wait and verify no more executions
time.Sleep(120 * time.Millisecond)
countAfterStop := counter.Load()
// Allow 1 additional execution (could be in progress when stopped)
if countAfterStop > countBeforeStop+1 {
t.Errorf("Expected periodic task to stop executing, before: %d, after: %d",
countBeforeStop, countAfterStop)
}
}
// Test Stopping Goroutines
func TestStopGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
stopped := atomic.Bool{}
gm.StartGoroutine("stop-test", func(ctx context.Context) {
<-ctx.Done()
stopped.Store(true)
})
// Wait for goroutine to start
time.Sleep(50 * time.Millisecond)
gm.StopGoroutine("stop-test")
// Wait for goroutine to stop
time.Sleep(50 * time.Millisecond)
if !stopped.Load() {
t.Error("Expected goroutine to be stopped")
}
status := gm.GetStatus()
if goroutineStatus, exists := status["stop-test"]; exists {
if goroutineStatus.Running {
t.Error("Expected goroutine to be marked as not running")
}
}
}
func TestStopGoroutineNonExistent(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Should not panic or error when stopping non-existent goroutine
gm.StopGoroutine("non-existent")
}
func TestStopGoroutineAlreadyStopped(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
gm.StartGoroutine("already-stopped", func(ctx context.Context) {
// Exit immediately
})
// Wait for goroutine to finish
time.Sleep(50 * time.Millisecond)
// Try to stop already-stopped goroutine (should be safe)
gm.StopGoroutine("already-stopped")
}
// Test Shutdown
func TestShutdownGraceful(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
counter := atomic.Int32{}
// Start multiple goroutines
for i := 0; i < 5; i++ {
name := "goroutine-" + string(rune('0'+i))
gm.StartGoroutine(name, func(ctx context.Context) {
counter.Add(1)
<-ctx.Done()
counter.Add(-1)
})
}
// Wait for all to start
time.Sleep(100 * time.Millisecond)
if counter.Load() != 5 {
t.Errorf("Expected 5 goroutines running, got %d", counter.Load())
}
// Shutdown with generous timeout
err := gm.Shutdown(time.Second)
if err != nil {
t.Errorf("Expected graceful shutdown, got error: %v", err)
}
if counter.Load() != 0 {
t.Errorf("Expected all goroutines to complete cleanup, got %d still running", counter.Load())
}
}
func TestShutdownWithTimeout(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Start a goroutine that ignores cancellation (bad behavior, but testing timeout)
gm.StartGoroutine("stubborn", func(ctx context.Context) {
// Simulate a goroutine that takes too long to stop
time.Sleep(500 * time.Millisecond)
})
time.Sleep(50 * time.Millisecond)
// Shutdown with very short timeout
err := gm.Shutdown(10 * time.Millisecond)
if err == nil {
t.Error("Expected timeout error")
}
if err != ErrShutdownTimeout {
t.Errorf("Expected ErrShutdownTimeout, got %v", err)
}
}
func TestShutdownEmpty(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Shutdown with no goroutines should succeed immediately
err := gm.Shutdown(time.Second)
if err != nil {
t.Errorf("Expected no error for empty shutdown, got: %v", err)
}
}
// Test Status
func TestGetStatus(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Start multiple goroutines with different states
gm.StartGoroutine("running", func(ctx context.Context) {
<-ctx.Done()
})
gm.StartGoroutine("quick", func(ctx context.Context) {
// Exits immediately
})
time.Sleep(50 * time.Millisecond)
status := gm.GetStatus()
if len(status) != 2 {
t.Errorf("Expected 2 goroutines in status, got %d", len(status))
}
if runningStatus, exists := status["running"]; exists {
if !runningStatus.Running {
t.Error("Expected 'running' goroutine to be marked as running")
}
if runningStatus.Name != "running" {
t.Errorf("Expected name 'running', got %s", runningStatus.Name)
}
if runningStatus.StartTime.IsZero() {
t.Error("Expected non-zero start time")
}
if runningStatus.Runtime <= 0 {
t.Error("Expected positive runtime")
}
} else {
t.Error("Expected 'running' goroutine in status")
}
if quickStatus, exists := status["quick"]; exists {
if quickStatus.Running {
t.Error("Expected 'quick' goroutine to be marked as not running")
}
} else {
t.Error("Expected 'quick' goroutine in status")
}
}
func TestGetStatusEmpty(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
status := gm.GetStatus()
if status == nil {
t.Fatal("Expected non-nil status map")
}
if len(status) != 0 {
t.Errorf("Expected empty status, got %d entries", len(status))
}
}
// Test Concurrent Operations
func TestConcurrentStartGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(2 * time.Second)
counter := atomic.Int32{}
const numGoroutines = 50
// Start many goroutines concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
name := "concurrent-" + string(rune('0'+id%10)) + string(rune('0'+id/10))
gm.StartGoroutine(name, func(ctx context.Context) {
counter.Add(1)
time.Sleep(50 * time.Millisecond)
counter.Add(-1)
})
}(i)
}
// Wait for all to start
time.Sleep(150 * time.Millisecond)
// Verify goroutines are tracked
status := gm.GetStatus()
if len(status) < numGoroutines/2 {
t.Errorf("Expected at least %d goroutines, got %d", numGoroutines/2, len(status))
}
}
func TestConcurrentStopGoroutine(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
const numGoroutines = 20
// Start goroutines
for i := 0; i < numGoroutines; i++ {
name := "stop-concurrent-" + string(rune('0'+i%10))
gm.StartGoroutine(name, func(ctx context.Context) {
<-ctx.Done()
})
}
time.Sleep(50 * time.Millisecond)
// Stop all concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
name := "stop-concurrent-" + string(rune('0'+id%10))
gm.StopGoroutine(name)
}(i)
}
time.Sleep(100 * time.Millisecond)
// Verify all stopped
status := gm.GetStatus()
for _, s := range status {
if s.Running {
t.Errorf("Expected goroutine %s to be stopped", s.Name)
}
}
}
func TestConcurrentGetStatus(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
// Start some goroutines
for i := 0; i < 10; i++ {
name := "status-test-" + string(rune('0'+i))
gm.StartGoroutine(name, func(ctx context.Context) {
<-ctx.Done()
})
}
// Concurrently read status many times (should not race)
done := make(chan struct{})
for i := 0; i < 20; i++ {
go func() {
for j := 0; j < 100; j++ {
_ = gm.GetStatus()
}
done <- struct{}{}
}()
}
// Wait for all concurrent reads
for i := 0; i < 20; i++ {
<-done
}
}
// Test Error Cases
func TestShutdownTimeoutError(t *testing.T) {
err := ErrShutdownTimeout
if err.Error() != "shutdown timeout: some goroutines did not stop in time" {
t.Errorf("Unexpected error message: %s", err.Error())
}
}
// Test Edge Cases
func TestStartGoroutineAfterShutdown(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// Shutdown immediately
_ = gm.Shutdown(time.Second)
executed := atomic.Bool{}
// Try to start goroutine after shutdown
gm.StartGoroutine("after-shutdown", func(ctx context.Context) {
executed.Store(true)
<-ctx.Done()
})
time.Sleep(50 * time.Millisecond)
// Goroutine should have started but context already canceled
// It may or may not execute depending on timing, but shouldn't panic
status := gm.GetStatus()
if _, exists := status["after-shutdown"]; exists {
// If it's in status, it was tracked (acceptable)
t.Log("Goroutine was tracked even after shutdown")
}
}
func TestMultipleShutdowns(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
// First shutdown
err1 := gm.Shutdown(time.Second)
if err1 != nil {
t.Errorf("Expected first shutdown to succeed, got: %v", err1)
}
// Second shutdown (should not panic or error)
err2 := gm.Shutdown(time.Second)
if err2 != nil {
t.Errorf("Expected second shutdown to succeed, got: %v", err2)
}
}
func TestGoroutineWithImmediateReturn(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
executed := atomic.Bool{}
gm.StartGoroutine("immediate", func(ctx context.Context) {
executed.Store(true)
// Return immediately
})
time.Sleep(50 * time.Millisecond)
if !executed.Load() {
t.Error("Expected goroutine to execute")
}
status := gm.GetStatus()
if goroutineStatus, exists := status["immediate"]; exists {
if goroutineStatus.Running {
t.Error("Expected immediately-returning goroutine to be marked as not running")
}
}
}
func TestPeriodicTaskPanicRecovery(t *testing.T) {
logger := GetSingletonNoOpLogger()
gm := NewGoroutineManager(logger)
defer gm.Shutdown(time.Second)
counter := atomic.Int32{}
gm.StartPeriodicTask("panic-periodic", 50*time.Millisecond, func() {
counter.Add(1)
if counter.Load() == 2 {
panic("periodic panic")
}
})
// Wait for panic to occur
time.Sleep(200 * time.Millisecond)
// After panic, the goroutine should have stopped
status := gm.GetStatus()
if goroutineStatus, exists := status["panic-periodic"]; exists {
if goroutineStatus.Running {
t.Error("Expected panicked periodic task to stop")
}
}
}
-764
View File
@@ -1,764 +0,0 @@
package handlers
import (
"errors"
"net/http"
"sync"
"testing"
"time"
)
// ============================================================================
// OAuth Handler Tests
// ============================================================================
func TestOAuthHandler(t *testing.T) {
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
// Test authorization request handling logic
logger := &MockLogger{}
tests := []struct {
name string
requestURL string
expectedStatus int
checkLocation bool
}{
{
name: "Valid authorization request",
requestURL: "/auth/login",
expectedStatus: http.StatusFound,
checkLocation: true,
},
{
name: "With return URL",
requestURL: "/auth/login?return=/dashboard",
expectedStatus: http.StatusFound,
checkLocation: true,
},
}
// Test the test case structure
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.requestURL == "" {
t.Error("Request URL should not be empty")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// In a real implementation, this would test the actual handler
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Authorization request test completed")
})
t.Run("HandleCallbackRequest", func(t *testing.T) {
// Test callback request handling with existing mocks
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
tests := []struct {
name string
queryParams string
expectedStatus int
expectError bool
}{
{
name: "Valid callback with code",
queryParams: "code=test-code&state=test-state",
expectedStatus: http.StatusFound,
expectError: false,
},
{
name: "Callback with error",
queryParams: "error=access_denied&error_description=User denied access",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing code",
queryParams: "state=test-state",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
{
name: "Missing state",
queryParams: "code=test-code",
expectedStatus: http.StatusBadRequest,
expectError: true,
},
}
// Test the callback scenarios
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Verify test case parameters
if test.queryParams == "" && !test.expectError {
t.Error("Query params should not be empty for successful cases")
}
if test.expectedStatus == 0 {
t.Error("Expected status should be set")
}
// Test session manager functionality
if sessionManager != nil {
t.Logf("Session manager available for test %s", test.name)
}
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
})
}
// Verify logger doesn't cause issues
logger.Debugf("Callback request test completed")
})
t.Run("HandleLogout", func(t *testing.T) {
// Test logout functionality with mock implementations
sessionManager := NewMockSessionManager()
logger := &MockLogger{}
// Test session clearing
mockReq := &http.Request{}
session, err := sessionManager.GetSession(mockReq)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up authenticated session
err = session.SetAuthenticated(true)
if err != nil {
t.Fatalf("Failed to set authentication: %v", err)
}
session.SetIDToken("test-token")
// Verify session is authenticated
if !session.GetAuthenticated() {
t.Error("Session should be authenticated before logout")
}
// Test logout by clearing session
// session.Clear() // Method not implemented in SessionData
// Additional logout verification would go here
// Verify logger doesn't cause issues
logger.Debugf("Logout test completed")
t.Log("Logout test completed successfully")
})
}
// ============================================================================
// Auth Handler Tests
// ============================================================================
func TestAuthHandler(t *testing.T) {
t.Run("HandleAuthentication", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
/*
handler := &MockAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
tests := []struct {
name string
setupSession func(*MockSession)
expectedStatus int
expectNext bool
}{
{
name: "Authenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("valid-token")
},
expectedStatus: http.StatusOK,
expectNext: true,
},
{
name: "Unauthenticated user",
setupSession: func(s *MockSession) {
s.SetAuthenticated(false)
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
{
name: "Expired token",
setupSession: func(s *MockSession) {
s.SetAuthenticated(true)
s.SetIDToken("expired-token")
},
expectedStatus: http.StatusUnauthorized,
expectNext: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("HandleRefreshToken", func(t *testing.T) {
// Test authentication handling with mock types
// validator := &MockTokenValidator{valid: true} // Currently unused
tests := []struct {
name string
refreshToken string
mockResponse *MockTokenResponse
mockError error
expectSuccess bool
}{
{
name: "Successful refresh",
refreshToken: "valid-refresh-token",
mockResponse: &MockTokenResponse{
AccessToken: "new-access-token",
IDToken: "new-id-token",
RefreshToken: "new-refresh-token",
},
expectSuccess: true,
},
{
name: "Failed refresh",
refreshToken: "invalid-refresh-token",
mockError: errors.New("invalid_grant"),
expectSuccess: false,
},
{
name: "Empty refresh token",
refreshToken: "",
expectSuccess: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Error Handler Tests
// ============================================================================
func TestErrorHandler(t *testing.T) {
t.Run("HandleHTTPErrors", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
errorCode int
errorMessage string
isAjax bool
expectedStatus int
expectedBody string
}{
{
name: "401 Unauthorized",
errorCode: http.StatusUnauthorized,
errorMessage: "Authentication required",
isAjax: false,
expectedStatus: http.StatusUnauthorized,
expectedBody: "Authentication required",
},
{
name: "403 Forbidden",
errorCode: http.StatusForbidden,
errorMessage: "Access denied",
isAjax: false,
expectedStatus: http.StatusForbidden,
expectedBody: "Access denied",
},
{
name: "500 Internal Server Error",
errorCode: http.StatusInternalServerError,
errorMessage: "Internal server error",
isAjax: false,
expectedStatus: http.StatusInternalServerError,
expectedBody: "Internal server error",
},
{
name: "Ajax 401",
errorCode: http.StatusUnauthorized,
errorMessage: "Token expired",
isAjax: true,
expectedStatus: http.StatusUnauthorized,
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
t.Run("RecoverFromPanic", func(t *testing.T) {
// Test with mock implementations
/*
handler := &MockErrorHandler{
logger: &MockLogger{},
}
*/
tests := []struct {
name string
panicValue interface{}
expectError bool
}{
{
name: "String panic",
panicValue: "something went wrong",
expectError: true,
},
{
name: "Error panic",
panicValue: errors.New("critical error"),
expectError: true,
},
{
name: "Nil panic",
panicValue: nil,
expectError: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Azure OAuth Callback Tests
// ============================================================================
func TestAzureOAuthCallback(t *testing.T) {
t.Run("AzureSpecificClaims", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
azureClaims := map[string]interface{}{
"oid": "object-id",
"tid": "tenant-id",
"preferred_username": "user@example.com",
"name": "Test User",
"email": "user@example.com",
"groups": []string{"group1", "group2"},
}
// Test would go here when properly implemented
_ = azureClaims
})
t.Run("AzureTokenValidation", func(t *testing.T) {
// Test with mock validator types
/*
validator := &MockAzureTokenValidator{
tenantID: "test-tenant",
clientID: "test-client",
}
*/
tests := []struct {
name string
token string
claims map[string]interface{}
expectValid bool
}{
{
name: "Valid Azure token",
token: "valid-azure-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: true,
},
{
name: "Wrong tenant",
token: "wrong-tenant-token",
claims: map[string]interface{}{
"aud": "test-client",
"tid": "wrong-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
{
name: "Wrong audience",
token: "wrong-audience-token",
claims: map[string]interface{}{
"aud": "wrong-client",
"tid": "test-tenant",
"exp": float64(time.Now().Add(time.Hour).Unix()),
},
expectValid: false,
},
}
// Test the authentication test cases
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test with mock session
mockSession := &MockSession{values: make(map[string]interface{})}
// Use mock session to avoid unused variable error
_ = mockSession
t.Logf("Testing %s", test.name)
})
}
})
}
// ============================================================================
// Concurrent Handler Tests
// ============================================================================
func TestConcurrentHandlers(t *testing.T) {
t.Run("ConcurrentCallbacks", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
successCount := int32(0)
errorCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = successCount
_ = errorCount
})
t.Run("ConcurrentLogouts", func(t *testing.T) {
// Test with mock configuration
/*
handler := &OAuthHandler{
logger: &MockLogger{},
sessionManager: NewMockSessionManager(),
}
*/
var wg sync.WaitGroup
logoutCount := int32(0)
// Test would go here when properly implemented
wg.Wait() // Proper usage instead of assignment
_ = logoutCount
})
}
// ============================================================================
// Mock Implementations
// ============================================================================
type MockSessionManager struct {
sessions map[string]*MockSession
mu sync.RWMutex
}
func NewMockSessionManager() *MockSessionManager {
return &MockSessionManager{
sessions: make(map[string]*MockSession),
}
}
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
m.mu.Lock()
defer m.mu.Unlock()
sessionID := "test-session"
if session, exists := m.sessions[sessionID]; exists {
return session, nil
}
session := &MockSession{
values: make(map[string]interface{}),
}
m.sessions[sessionID] = session
return session, nil
}
type MockSession struct {
values map[string]interface{}
mu sync.RWMutex
}
func (s *MockSession) SetAuthenticated(auth bool) error {
s.mu.Lock()
defer s.mu.Unlock()
s.values["authenticated"] = auth
return nil
}
func (s *MockSession) GetAuthenticated() bool {
s.mu.RLock()
defer s.mu.RUnlock()
auth, ok := s.values["authenticated"].(bool)
return ok && auth
}
func (s *MockSession) SetIDToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["id_token"] = token
}
func (s *MockSession) GetIDToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["id_token"].(string)
return token
}
func (s *MockSession) SetAccessToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["access_token"] = token
}
func (s *MockSession) GetAccessToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["access_token"].(string)
return token
}
func (s *MockSession) SetRefreshToken(token string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["refresh_token"] = token
}
func (s *MockSession) GetRefreshToken() string {
s.mu.RLock()
defer s.mu.RUnlock()
token, _ := s.values["refresh_token"].(string)
return token
}
func (s *MockSession) SetState(state string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["state"] = state
}
func (s *MockSession) GetState() string {
s.mu.RLock()
defer s.mu.RUnlock()
state, _ := s.values["state"].(string)
return state
}
func (s *MockSession) SetClaims(claims map[string]interface{}) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["claims"] = claims
}
func (s *MockSession) GetClaims() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
claims, _ := s.values["claims"].(map[string]interface{})
return claims
}
// Additional SessionData interface methods to match real interface
func (s *MockSession) GetCSRF() string {
s.mu.RLock()
defer s.mu.RUnlock()
csrf, _ := s.values["csrf"].(string)
return csrf
}
func (s *MockSession) GetNonce() string {
s.mu.RLock()
defer s.mu.RUnlock()
nonce, _ := s.values["nonce"].(string)
return nonce
}
func (s *MockSession) GetCodeVerifier() string {
s.mu.RLock()
defer s.mu.RUnlock()
verifier, _ := s.values["code_verifier"].(string)
return verifier
}
func (s *MockSession) GetIncomingPath() string {
s.mu.RLock()
defer s.mu.RUnlock()
path, _ := s.values["incoming_path"].(string)
return path
}
func (s *MockSession) SetEmail(email string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["email"] = email
}
func (s *MockSession) GetEmail() string {
s.mu.RLock()
defer s.mu.RUnlock()
email, _ := s.values["email"].(string)
return email
}
func (s *MockSession) SetCSRF(csrf string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["csrf"] = csrf
}
func (s *MockSession) SetNonce(nonce string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["nonce"] = nonce
}
func (s *MockSession) SetCodeVerifier(verifier string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["code_verifier"] = verifier
}
func (s *MockSession) SetIncomingPath(path string) {
s.mu.Lock()
defer s.mu.Unlock()
s.values["incoming_path"] = path
}
func (s *MockSession) ResetRedirectCount() {
s.mu.Lock()
defer s.mu.Unlock()
s.values["redirect_count"] = 0
}
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
return nil
}
func (s *MockSession) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.values = make(map[string]interface{})
}
func (s *MockSession) returnToPoolSafely() {
// No-op for mock
}
type MockTokenValidator struct {
valid bool
}
func (v *MockTokenValidator) Validate(token string) bool {
if token == "expired-token" {
return false
}
return v.valid
}
// ============================================================================
// Mock Handler Type Definitions (for testing)
// ============================================================================
// These mock handlers are simplified versions for testing purposes
// They don't match the actual handler implementations
type MockAuthHandler struct{}
type MockErrorHandler struct{}
type MockAzureTokenValidator struct {
tenantID string
clientID string
}
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
// Validate tenant ID
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
return false
}
// Validate audience
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
return false
}
// Validate expiration
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return false
}
}
return true
}
// ============================================================================
// Helper Types and Mock Logger
// ============================================================================
type MockLogger struct{}
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
func (l *MockLogger) Error(msg string) {}
type MockTokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
}
-308
View File
@@ -1,308 +0,0 @@
// Package handlers provides HTTP request handlers for the OIDC middleware.
package handlers
import (
"context"
"fmt"
"net/http"
"strings"
)
// OAuthHandler handles OAuth callback requests
type OAuthHandler struct {
logger Logger
sessionManager SessionManager
tokenExchanger TokenExchanger
tokenVerifier TokenVerifier
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
isAllowedDomainFunc func(email string) bool
redirURLPath string
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
}
// Logger interface for dependency injection
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
Error(msg string)
}
// SessionManager interface for session operations
type SessionManager interface {
GetSession(req *http.Request) (SessionData, error)
}
// SessionData interface for session data operations
type SessionData interface {
GetCSRF() string
GetNonce() string
GetCodeVerifier() string
GetIncomingPath() string
GetAuthenticated() bool
GetAccessToken() string
GetRefreshToken() string
GetIDToken() string
GetEmail() string
SetAuthenticated(bool) error
SetEmail(string)
SetIDToken(string)
SetAccessToken(string)
SetRefreshToken(string)
SetCSRF(string)
SetNonce(string)
SetCodeVerifier(string)
SetIncomingPath(string)
ResetRedirectCount()
Save(req *http.Request, rw http.ResponseWriter) error
returnToPoolSafely()
}
// TokenExchanger interface for token operations
type TokenExchanger interface {
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
}
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// TokenResponse represents the response from token exchange
type TokenResponse struct {
IDToken string
AccessToken string
RefreshToken string
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
isAllowedDomainFunc func(string) bool, redirURLPath string,
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
return &OAuthHandler{
logger: logger,
sessionManager: sessionManager,
tokenExchanger: tokenExchanger,
tokenVerifier: tokenVerifier,
extractClaimsFunc: extractClaimsFunc,
isAllowedDomainFunc: isAllowedDomainFunc,
redirURLPath: redirURLPath,
sendErrorResponseFunc: sendErrorResponseFunc,
}
}
// HandleCallback handles OAuth callback requests
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := h.sessionManager.GetSession(req)
if err != nil {
h.logger.Errorf("Session error during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
return
}
defer session.returnToPoolSafely()
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Debug logging for cookie configuration
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
// Log all cookies in the request for debugging
cookies := req.Cookies()
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
for _, cookie := range cookies {
if strings.HasPrefix(cookie.Name, "_oidc_") {
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
}
}
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error")
}
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
if state == "" {
h.logger.Error("No state in callback")
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
// Debug log the state parameter received
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
csrfToken := session.GetCSRF()
if csrfToken == "" {
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
session.GetAuthenticated(), req.URL.String())
// Enhanced debugging for missing CSRF token
cookie, err := req.Cookie("_oidc_raczylo_m")
if err != nil {
h.logger.Errorf("Main session cookie not found in request: %v", err)
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
} else {
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
}
// Log session state for debugging
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
session.GetAuthenticated(), session.GetAccessToken() != "")
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
// Debug log successful CSRF token retrieval
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
if state != csrfToken {
h.logger.Error("State parameter does not match CSRF token in session during callback")
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
h.logger.Error("No code in callback")
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
h.logger.Errorf("Failed to extract claims during callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
h.logger.Error("Nonce claim missing in id_token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
h.logger.Error("Nonce not found in session during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
h.logger.Error("Nonce claim does not match session nonce during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
email, _ := claims["email"].(string)
if email == "" {
h.logger.Errorf("Email claim missing or empty in token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
}
if !h.isAllowedDomainFunc(email) {
h.logger.Errorf("Disallowed email domain during callback: %s", email)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
return
}
if err := session.SetAuthenticated(true); err != nil {
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
session.ResetRedirectCount()
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("")
if err := session.Save(req, rw); err != nil {
h.logger.Errorf("Failed to save session after callback: %v", err)
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
return
}
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// URLHelper provides utility methods for URL operations
type URLHelper struct {
logger Logger
}
// NewURLHelper creates a new URL helper
func NewURLHelper(logger Logger) *URLHelper {
return &URLHelper{logger: logger}
}
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
// It compares the request path against configured excluded URL prefixes.
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
for excludedURL := range excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
return false
}
// DetermineScheme determines the URL scheme for building redirect URLs.
// It checks X-Forwarded-Proto header first, then TLS presence.
func (h *URLHelper) DetermineScheme(req *http.Request) string {
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
return "https"
}
return "http"
}
// DetermineHost determines the host for building redirect URLs.
// It checks X-Forwarded-Host header first, then falls back to req.Host.
func (h *URLHelper) DetermineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
}
return req.Host
}
+74 -13
View File
@@ -13,8 +13,25 @@ import (
"net/url"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// newUUIDv4 returns an RFC 4122 v4 UUID string (e.g.
// "f47ac10b-58cc-4372-a567-0e02b2c3d479") backed by crypto/rand. Used for CSRF
// tokens and other opaque random identifiers — replaces github.com/google/uuid
// to keep the plugin stdlib-only on the production path.
func newUUIDv4() (string, error) {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return "", fmt.Errorf("could not generate UUID: %w", err)
}
b[6] = (b[6] & 0x0f) | 0x40 // version 4
b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
}
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
// The nonce is used to prevent replay attacks and associate client sessions with ID tokens.
// Returns:
@@ -90,9 +107,12 @@ type TokenResponse struct {
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant)
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
"grant_type": {grantType},
}
// client_id is sent in the body for every method except client_secret_basic,
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
data.Set("client_id", t.clientID)
}
if grantType == "authorization_code" {
@@ -109,7 +129,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
client := t.tokenHTTPClient
if client == nil {
// Use shared transport pool to prevent memory leaks
jar, _ := cookiejar.New(nil)
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
pooledClient := CreateTokenHTTPClient()
client = &http.Client{
Transport: pooledClient.Transport,
@@ -124,24 +144,46 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
}
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
// Read tokenURL with RLock — needed as audience for private_key_jwt (RFC 7523 §3).
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
useBasicAuth := false
if t.clientAssertion != nil {
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
if err != nil {
return nil, fmt.Errorf("failed to sign client assertion: %w", err)
}
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
data.Set("client_assertion", assertion)
} else if t.clientAuthMethod == "client_secret_basic" {
useBasicAuth = true
} else {
data.Set("client_secret", t.clientSecret)
}
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if useBasicAuth {
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
defer func() {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining response body on defer
_ = resp.Body.Close() // Safe to ignore: closing body on defer
}()
if resp.StatusCode != http.StatusOK {
limitReader := io.LimitReader(resp.Body, 1024*10)
bodyBytes, _ := io.ReadAll(limitReader)
bodyBytes, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
@@ -232,7 +274,7 @@ func NewTokenCache() *TokenCache {
// - expiration: The duration for which the cache entry should be valid
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
_ = tc.cache.Set(token, claims, expiration) // Safe to ignore: cache failures are non-critical
}
// Get retrieves cached claims for a token.
@@ -329,6 +371,7 @@ func createStringMap(keys []string) map[string]struct{} {
// and redirects to the provider's logout endpoint or configured post-logout URI.
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing logout request")
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
@@ -344,8 +387,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
host := utils.DetermineHost(req)
scheme := utils.DetermineScheme(req, t.forceHTTPS)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
@@ -355,8 +398,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
// Read endSessionURL with RLock
t.metadataMu.RLock()
endSessionURL := t.endSessionURL
t.metadataMu.RUnlock()
if endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
@@ -395,6 +443,19 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
return u.String(), nil
}
// setOAuthBasicAuth sets the Authorization header per RFC 6749 §2.3.1: the
// client_id and client_secret are form-urlencoded individually, joined with a
// colon, then base64-encoded. This differs from http.Request.SetBasicAuth,
// which skips the form-urlencode step — that matters for credentials with
// reserved characters (`:`, `@`, `+`, `%`, etc.) where the wire format would
// otherwise diverge from what the spec mandates.
func setOAuthBasicAuth(req *http.Request, clientID, clientSecret string) {
user := url.QueryEscape(clientID)
pass := url.QueryEscape(clientSecret)
auth := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
req.Header.Set("Authorization", "Basic "+auth)
}
// deduplicateScopes removes duplicate scopes from a slice while preserving order.
// This ensures that OAuth scope parameters don't contain duplicates which could
// cause issues with some authorization servers.
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import (
"regexp"
"testing"
)
// TestNewUUIDv4 verifies the in-house UUID v4 generator produces RFC 4122
// compliant identifiers. Locks in the replacement for github.com/google/uuid
// — a regression here would weaken the CSRF token used in the OIDC flow.
func TestNewUUIDv4(t *testing.T) {
rfc4122v4 := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
const samples = 1000
seen := make(map[string]struct{}, samples)
for i := 0; i < samples; i++ {
got, err := newUUIDv4()
if err != nil {
t.Fatalf("newUUIDv4 failed: %v", err)
}
if !rfc4122v4.MatchString(got) {
t.Fatalf("UUID %q does not match RFC 4122 v4 format", got)
}
if _, dup := seen[got]; dup {
t.Fatalf("duplicate UUID emitted within %d samples: %q", samples, got)
}
seen[got] = struct{}{}
}
}
+35 -22
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
@@ -12,27 +13,26 @@ import (
// HTTPClientConfig provides configuration for creating HTTP clients
type HTTPClientConfig struct {
// Timeout for the entire request
Timeout time.Duration
// MaxRedirects allowed (0 means follow Go's default of 10)
MaxRedirects int
// UseCookieJar enables cookie jar for the client
UseCookieJar bool
// Connection settings
IdleConnTimeout time.Duration
MaxIdleConns int
ReadBufferSize int
DialTimeout time.Duration
KeepAlive time.Duration
TLSHandshakeTimeout time.Duration
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
IdleConnTimeout time.Duration
// Connection pool settings
MaxIdleConns int
MaxIdleConnsPerHost int
MaxConnsPerHost int
// Buffer settings
WriteBufferSize int
ReadBufferSize int
// Feature flags
MaxRedirects int
MaxIdleConnsPerHost int
Timeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
// RootCAs is an optional certificate pool used for TLS verification.
// A nil pool means "use the system trust store" (default behavior).
RootCAs *x509.CertPool
// InsecureSkipVerify disables TLS certificate verification.
// ONLY set this for local development against self-signed certificates.
InsecureSkipVerify bool
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
@@ -49,10 +49,10 @@ func DefaultHTTPClientConfig() HTTPClientConfig {
TLSHandshakeTimeout: 2 * time.Second,
ResponseHeaderTimeout: 3 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
IdleConnTimeout: 5 * time.Second,
MaxIdleConns: 20, // SECURITY FIX: Reduced from 100 to limit resource usage
MaxIdleConnsPerHost: 2, // SECURITY FIX: Reduced from 10 to prevent connection exhaustion
MaxConnsPerHost: 5, // SECURITY FIX: Reduced from 10 to limit concurrent connections
IdleConnTimeout: 30 * time.Second, // OPTIMIZATION: Increased for better connection reuse
MaxIdleConns: 50, // OPTIMIZATION: Increased from 20 for better connection pooling
MaxIdleConnsPerHost: 10, // OPTIMIZATION: Increased from 2 for better connection reuse
MaxConnsPerHost: 20, // OPTIMIZATION: Increased from 5 while maintaining security
WriteBufferSize: 4096,
ReadBufferSize: 4096,
ForceHTTP2: true,
@@ -70,6 +70,18 @@ func TokenHTTPClientConfig() HTTPClientConfig {
return config
}
// OIDCProviderHTTPClientConfig returns configuration optimized for OIDC provider calls
func OIDCProviderHTTPClientConfig() HTTPClientConfig {
config := DefaultHTTPClientConfig()
config.Timeout = 15 * time.Second // Slightly longer for OIDC operations
config.MaxIdleConns = 100 // Higher pool for frequent OIDC calls
config.MaxIdleConnsPerHost = 25 // More connections per OIDC provider
config.MaxConnsPerHost = 50 // Allow more concurrent requests to OIDC provider
config.IdleConnTimeout = 90 * time.Second // Keep connections alive longer for reuse
config.UseCookieJar = true // Enable cookie jar for session management
return config
}
// HTTPClientFactory provides methods for creating configured HTTP clients
type HTTPClientFactory struct{}
@@ -198,7 +210,8 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
PreferServerCipherSuites: true,
InsecureSkipVerify: false, // Always verify certificates
RootCAs: config.RootCAs,
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
},
ForceAttemptHTTP2: config.ForceHTTP2,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
@@ -233,7 +246,7 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
// Add cookie jar if requested
if config.UseCookieJar {
jar, _ := cookiejar.New(nil)
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
client.Jar = jar
}
+210
View File
@@ -0,0 +1,210 @@
package traefikoidc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestOIDCProviderHTTPClientConfigUnit tests OIDCProviderHTTPClientConfig function
func TestOIDCProviderHTTPClientConfigUnit(t *testing.T) {
config := OIDCProviderHTTPClientConfig()
// Verify OIDC-specific settings
assert.Equal(t, 15*time.Second, config.Timeout, "OIDC provider should have 15s timeout")
assert.Equal(t, 100, config.MaxIdleConns, "OIDC provider should have 100 max idle conns")
assert.Equal(t, 25, config.MaxIdleConnsPerHost, "OIDC provider should have 25 max idle conns per host")
assert.Equal(t, 50, config.MaxConnsPerHost, "OIDC provider should have 50 max conns per host")
assert.Equal(t, 90*time.Second, config.IdleConnTimeout, "OIDC provider should have 90s idle conn timeout")
assert.True(t, config.UseCookieJar, "OIDC provider should have cookie jar enabled")
}
// TestCreateDefaultClientUnit tests CreateDefaultClient function
func TestCreateDefaultClientUnit(t *testing.T) {
factory := NewHTTPClientFactory()
client := factory.CreateDefaultClient()
require.NotNil(t, client)
assert.NotNil(t, client.Transport, "client should have transport")
assert.Equal(t, 10*time.Second, client.Timeout, "default client should have 10s timeout")
}
// TestCreateTokenClientUnit tests CreateTokenClient function
func TestCreateTokenClientUnit(t *testing.T) {
factory := NewHTTPClientFactory()
client := factory.CreateTokenClient()
require.NotNil(t, client)
assert.NotNil(t, client.Transport, "client should have transport")
assert.NotNil(t, client.Jar, "token client should have cookie jar")
assert.Equal(t, 10*time.Second, client.Timeout, "token client should have 10s timeout")
}
// TestCreateHTTPClientWithConfigUnit tests CreateHTTPClientWithConfig function
func TestCreateHTTPClientWithConfigUnit(t *testing.T) {
config := HTTPClientConfig{
Timeout: 5 * time.Second,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 5,
UseCookieJar: true,
}
client := CreateHTTPClientWithConfig(config)
require.NotNil(t, client)
assert.Equal(t, 5*time.Second, client.Timeout)
assert.NotNil(t, client.Jar, "client should have cookie jar when configured")
}
// TestHTTPClientFactoryCreateHTTPClientValidation tests validation in CreateHTTPClient
func TestHTTPClientFactoryCreateHTTPClientValidation(t *testing.T) {
factory := NewHTTPClientFactory()
t.Run("zero values get defaults", func(t *testing.T) {
config := HTTPClientConfig{
// All zero values
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
// Verify defaults were applied
assert.Equal(t, 30*time.Second, client.Timeout)
})
t.Run("custom values preserved", func(t *testing.T) {
config := HTTPClientConfig{
Timeout: 15 * time.Second,
MaxIdleConns: 50,
MaxRedirects: 3,
UseCookieJar: true,
ForceHTTP2: true,
DisableKeepAlives: true,
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
assert.Equal(t, 15*time.Second, client.Timeout)
assert.NotNil(t, client.Jar)
})
t.Run("invalid timeout gets default", func(t *testing.T) {
config := HTTPClientConfig{
Timeout: -1 * time.Second, // Invalid
}
client := factory.CreateHTTPClient(config)
require.NotNil(t, client)
// Should get default due to validation failure
assert.Equal(t, 30*time.Second, client.Timeout)
})
}
// TestHTTPClientFactoryValidateHTTPClientConfig tests ValidateHTTPClientConfig
func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
factory := NewHTTPClientFactory()
tests := []struct {
name string
errorMsg string
config HTTPClientConfig
wantError bool
}{
{
name: "valid config",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: 50,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 20,
},
wantError: false,
},
{
name: "negative MaxIdleConns",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: -1,
},
wantError: true,
errorMsg: "MaxIdleConns cannot be negative",
},
{
name: "MaxIdleConns too high",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConns: 1500,
},
wantError: true,
errorMsg: "MaxIdleConns too high",
},
{
name: "negative MaxIdleConnsPerHost",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConnsPerHost: -1,
},
wantError: true,
errorMsg: "MaxIdleConnsPerHost cannot be negative",
},
{
name: "timeout too high",
config: HTTPClientConfig{
Timeout: 10 * time.Minute,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
},
wantError: true,
errorMsg: "timeout too high",
},
{
name: "negative timeout",
config: HTTPClientConfig{
Timeout: -1 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
},
wantError: true,
errorMsg: "timeout must be positive",
},
{
name: "MaxIdleConnsPerHost exceeds MaxConnsPerHost",
config: HTTPClientConfig{
Timeout: 10 * time.Second,
DialTimeout: 5 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
MaxIdleConnsPerHost: 50,
MaxConnsPerHost: 10,
},
wantError: true,
errorMsg: "MaxIdleConnsPerHost (50) cannot exceed MaxConnsPerHost (10)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := factory.ValidateHTTPClientConfig(&tt.config)
if tt.wantError {
assert.Error(t, err)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
+66 -19
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"sync"
@@ -12,19 +13,19 @@ import (
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
type SharedTransportPool struct {
mu sync.RWMutex
transports map[string]*sharedTransport
maxConns int
ctx context.Context
transports map[string]*sharedTransport
cancel context.CancelFunc
clientCount int32 // SECURITY FIX: Track total HTTP clients
maxClients int32 // SECURITY FIX: Limit total clients to 5
maxConns int
mu sync.RWMutex
clientCount int32
maxClients int32
}
type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
lastUsed time.Time
}
var (
@@ -103,7 +104,8 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
PreferServerCipherSuites: true,
InsecureSkipVerify: false,
RootCAs: config.RootCAs,
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
},
ForceAttemptHTTP2: config.ForceHTTP2,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
@@ -146,6 +148,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) {
}
// cleanupIdleTransports periodically cleans up unused transports
// Uses two-phase cleanup to minimize lock contention:
// 1. Find candidates while holding read lock
// 2. Remove and close transports with minimal lock duration
func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -155,17 +160,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
p.mu.Lock()
now := time.Now()
for transportKey, shared := range p.transports {
// Clean up transports not used for 2 minutes with no references
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(p.transports, transportKey)
// SECURITY FIX: Decrement client count when removing transport
atomic.AddInt32(&p.clientCount, -1)
}
p.performCleanup()
}
}
}
// performCleanup does the actual cleanup with optimized locking
func (p *SharedTransportPool) performCleanup() {
now := time.Now()
// Phase 1: Find candidates while holding read lock (fast)
p.mu.RLock()
candidates := make([]string, 0)
for transportKey, shared := range p.transports {
// Clean up transports not used for 2 minutes with no references
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
candidates = append(candidates, transportKey)
}
}
p.mu.RUnlock()
if len(candidates) == 0 {
return
}
// Phase 2: Remove and close each candidate individually
// This minimizes lock contention and allows concurrent access
for _, key := range candidates {
p.mu.Lock()
shared, exists := p.transports[key]
if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
// Remove from map first (releases memory)
delete(p.transports, key)
atomic.AddInt32(&p.clientCount, -1)
p.mu.Unlock()
// Close idle connections outside the lock (can be slow)
if shared.transport != nil {
shared.transport.CloseIdleConnections()
}
} else {
p.mu.Unlock()
}
}
@@ -173,8 +207,21 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
// configKey generates a unique key for a config
func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
// Simple key based on main parameters
return string(rune(config.MaxConnsPerHost)) + string(rune(config.MaxIdleConnsPerHost))
// Pool transports by the parameters that change TLS or connection
// behavior. RootCAs and InsecureSkipVerify MUST be part of the key:
// otherwise a middleware configured with a custom CA would share a
// transport with one using the system store, silently bypassing its
// CA configuration.
skip := "0"
if config.InsecureSkipVerify {
skip = "1"
}
return fmt.Sprintf("%d|%d|%p|%s",
config.MaxConnsPerHost,
config.MaxIdleConnsPerHost,
config.RootCAs,
skip,
)
}
// Cleanup closes all transports and stops the cleanup goroutine
+691
View File
@@ -0,0 +1,691 @@
package traefikoidc
import (
"context"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSharedTransportPoolGetOrCreateTransport tests transport creation and reuse
func TestSharedTransportPoolGetOrCreateTransport(t *testing.T) {
t.Run("create new transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount))
assert.Len(t, pool.transports, 1)
})
t.Run("reuse existing transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport1 := pool.GetOrCreateTransport(config)
transport2 := pool.GetOrCreateTransport(config)
assert.Equal(t, transport1, transport2, "should reuse same transport")
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount), "client count should not increase")
// Check ref count
pool.mu.RLock()
key := pool.configKey(config)
shared := pool.transports[key]
pool.mu.RUnlock()
assert.Equal(t, 2, shared.refCount, "ref count should be 2")
})
t.Run("client limit enforcement", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 5, // Already at max
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
assert.Nil(t, transport, "should return nil when at client limit")
})
t.Run("client limit with existing transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
// Create first transport
config1 := DefaultHTTPClientConfig()
transport1 := pool.GetOrCreateTransport(config1)
require.NotNil(t, transport1)
// Set client count to max
atomic.StoreInt32(&pool.clientCount, 5)
// Try to create with different config
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 15 // Different config
transport2 := pool.GetOrCreateTransport(config2)
// Should return existing transport since at limit
assert.NotNil(t, transport2)
assert.Equal(t, transport1, transport2)
})
t.Run("updates last used time", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.mu.RLock()
key := pool.configKey(config)
firstTime := pool.transports[key].lastUsed
pool.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
// Get again
transport2 := pool.GetOrCreateTransport(config)
require.NotNil(t, transport2)
pool.mu.RLock()
secondTime := pool.transports[key].lastUsed
pool.mu.RUnlock()
assert.True(t, secondTime.After(firstTime), "lastUsed should be updated")
})
}
// TestSharedTransportPoolReleaseTransport tests transport release
func TestSharedTransportPoolReleaseTransport(t *testing.T) {
t.Run("decrement ref count", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Get again to increase ref count
pool.GetOrCreateTransport(config)
pool.mu.RLock()
key := pool.configKey(config)
refCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 2, refCount)
// Release
pool.ReleaseTransport(transport)
pool.mu.RLock()
newRefCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 1, newRefCount)
})
t.Run("ref count reaches zero", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.mu.RLock()
key := pool.configKey(config)
pool.mu.RUnlock()
// Release to zero
pool.ReleaseTransport(transport)
pool.mu.RLock()
shared := pool.transports[key]
pool.mu.RUnlock()
assert.Equal(t, 0, shared.refCount)
assert.NotZero(t, shared.lastUsed, "lastUsed should be set")
})
t.Run("release non-existent transport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
// Create a transport not in the pool
fakeTransport := &http.Transport{}
// Should not panic
assert.NotPanics(t, func() {
pool.ReleaseTransport(fakeTransport)
})
})
t.Run("release updates last used", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
time.Sleep(10 * time.Millisecond)
beforeRelease := time.Now()
pool.ReleaseTransport(transport)
pool.mu.RLock()
key := pool.configKey(config)
lastUsed := pool.transports[key].lastUsed
pool.mu.RUnlock()
assert.True(t, lastUsed.After(beforeRelease) || lastUsed.Equal(beforeRelease))
})
}
// TestSharedTransportPoolCleanup tests cleanup functionality
func TestSharedTransportPoolCleanup(t *testing.T) {
t.Run("cleanup all transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create multiple transports
config1 := DefaultHTTPClientConfig()
pool.GetOrCreateTransport(config1)
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 15
pool.GetOrCreateTransport(config2)
assert.Greater(t, len(pool.transports), 0)
// Cleanup
pool.Cleanup()
assert.Len(t, pool.transports, 0, "all transports should be removed")
})
t.Run("cleanup cancels context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
pool.Cleanup()
select {
case <-pool.ctx.Done():
// Context was canceled
case <-time.After(100 * time.Millisecond):
t.Error("context should be canceled")
}
})
t.Run("cleanup with no transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
assert.NotPanics(t, func() {
pool.Cleanup()
})
})
t.Run("cleanup closes idle connections", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Cleanup should call CloseIdleConnections on each transport
pool.Cleanup()
// Verify transports map is cleared
assert.Empty(t, pool.transports)
})
}
// TestSharedTransportPoolCleanupIdleTransports tests periodic cleanup
func TestSharedTransportPoolCleanupIdleTransports(t *testing.T) {
if testing.Short() {
t.Skip("Skipping cleanup goroutine test in short mode")
}
t.Run("cleanup removes idle transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create transport and release it
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
pool.ReleaseTransport(transport)
// Set lastUsed to old time
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Start cleanup in background (simulating what would happen)
// Note: We're testing the cleanup logic manually here
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
atomic.AddInt32(&pool.clientCount, -1)
}
}
pool.mu.Unlock()
// Transport should be removed
pool.mu.RLock()
_, exists := pool.transports[key]
pool.mu.RUnlock()
assert.False(t, exists, "old idle transport should be removed")
})
t.Run("cleanup preserves active transports", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Create transport with refs
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
// Keep ref count > 0, but set old lastUsed
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Run cleanup logic
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
}
}
pool.mu.Unlock()
// Transport should still exist (has ref count)
pool.mu.RLock()
_, exists := pool.transports[key]
pool.mu.RUnlock()
assert.True(t, exists, "transport with references should be preserved")
})
t.Run("cleanup respects context cancellation", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
// Start cleanup goroutine
done := make(chan bool)
go func() {
pool.cleanupIdleTransports(ctx)
done <- true
}()
// Cancel context
cancel()
// Should exit quickly
select {
case <-done:
// Success
case <-time.After(2 * time.Second):
t.Error("cleanup goroutine should exit on context cancellation")
}
})
}
// TestCreatePooledHTTPClient tests pooled client creation
func TestCreatePooledHTTPClient(t *testing.T) {
t.Run("create client with default config", func(t *testing.T) {
config := DefaultHTTPClientConfig()
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
assert.NotNil(t, client.Transport)
assert.Equal(t, config.Timeout, client.Timeout)
})
t.Run("create multiple clients reuse transport", func(t *testing.T) {
// Reset global pool for clean test
globalTransportPoolOnce = sync.Once{}
globalTransportPool = nil
config := DefaultHTTPClientConfig()
client1 := CreatePooledHTTPClient(config)
client2 := CreatePooledHTTPClient(config)
require.NotNil(t, client1)
require.NotNil(t, client2)
// Should use same transport
assert.Equal(t, client1.Transport, client2.Transport)
})
t.Run("redirect policy is set", func(t *testing.T) {
config := DefaultHTTPClientConfig()
config.MaxRedirects = 3
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
assert.NotNil(t, client.CheckRedirect)
// Test redirect limit
var redirects []*http.Request
for i := 0; i < 3; i++ {
redirects = append(redirects, &http.Request{})
}
err := client.CheckRedirect(nil, redirects)
assert.Error(t, err, "should error after max redirects")
})
t.Run("default redirect limit", func(t *testing.T) {
config := DefaultHTTPClientConfig()
config.MaxRedirects = 0 // Should default to 10
client := CreatePooledHTTPClient(config)
require.NotNil(t, client)
// Test default redirect limit (10)
var redirects []*http.Request
for i := 0; i < 10; i++ {
redirects = append(redirects, &http.Request{})
}
err := client.CheckRedirect(nil, redirects)
assert.Error(t, err, "should error after 10 redirects")
})
}
// TestGetGlobalTransportPool tests singleton pattern
func TestGetGlobalTransportPool(t *testing.T) {
t.Run("returns same instance", func(t *testing.T) {
pool1 := GetGlobalTransportPool()
pool2 := GetGlobalTransportPool()
assert.Equal(t, pool1, pool2, "should return same singleton instance")
})
t.Run("pool is initialized", func(t *testing.T) {
pool := GetGlobalTransportPool()
require.NotNil(t, pool)
assert.NotNil(t, pool.transports)
assert.Equal(t, 20, pool.maxConns)
assert.Equal(t, int32(5), pool.maxClients)
assert.NotNil(t, pool.ctx)
assert.NotNil(t, pool.cancel)
})
}
// TestSharedTransportPoolConcurrency tests thread safety
func TestSharedTransportPoolConcurrency(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrency test in short mode")
}
t.Run("concurrent GetOrCreateTransport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 10, // Allow more for concurrency test
}
config := DefaultHTTPClientConfig()
const numGoroutines = 20
var wg sync.WaitGroup
transports := make([]*http.Transport, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
transports[idx] = pool.GetOrCreateTransport(config)
}(i)
}
wg.Wait()
// All should get same transport
firstTransport := transports[0]
for i := 1; i < numGoroutines; i++ {
if transports[i] != nil {
assert.Equal(t, firstTransport, transports[i])
}
}
})
t.Run("concurrent ReleaseTransport", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 10,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
// Increase ref count
for i := 0; i < 20; i++ {
pool.GetOrCreateTransport(config)
}
const numReleases = 20
var wg sync.WaitGroup
for i := 0; i < numReleases; i++ {
wg.Add(1)
go func() {
defer wg.Done()
pool.ReleaseTransport(transport)
}()
}
wg.Wait()
// Should not panic and ref count should be decremented
pool.mu.RLock()
key := pool.configKey(config)
refCount := pool.transports[key].refCount
pool.mu.RUnlock()
assert.Equal(t, 1, refCount, "ref count should be 1 after 20 releases from initial 21")
})
}
// TestSharedTransportPoolEdgeCases tests edge cases
func TestSharedTransportPoolEdgeCases(t *testing.T) {
t.Run("config key generation", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
}
config1 := DefaultHTTPClientConfig()
config1.MaxConnsPerHost = 10
config1.MaxIdleConnsPerHost = 5
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 10
config2.MaxIdleConnsPerHost = 5
key1 := pool.configKey(config1)
key2 := pool.configKey(config2)
assert.Equal(t, key1, key2, "same config should produce same key")
})
t.Run("different configs produce different keys", func(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
}
config1 := DefaultHTTPClientConfig()
config1.MaxConnsPerHost = 10
config2 := DefaultHTTPClientConfig()
config2.MaxConnsPerHost = 20
key1 := pool.configKey(config1)
key2 := pool.configKey(config2)
assert.NotEqual(t, key1, key2, "different configs should produce different keys")
})
t.Run("client count decrements on cleanup", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
clientCount: 0,
maxClients: 5,
ctx: ctx,
cancel: cancel,
}
config := DefaultHTTPClientConfig()
transport := pool.GetOrCreateTransport(config)
require.NotNil(t, transport)
initialCount := atomic.LoadInt32(&pool.clientCount)
assert.Equal(t, int32(1), initialCount)
// Release and mark as old
pool.ReleaseTransport(transport)
pool.mu.Lock()
key := pool.configKey(config)
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
pool.mu.Unlock()
// Run cleanup
pool.mu.Lock()
now := time.Now()
for transportKey, shared := range pool.transports {
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
shared.transport.CloseIdleConnections()
delete(pool.transports, transportKey)
atomic.AddInt32(&pool.clientCount, -1)
}
}
pool.mu.Unlock()
finalCount := atomic.LoadInt32(&pool.clientCount)
assert.Equal(t, int32(0), finalCount, "client count should decrement on cleanup")
})
}

Some files were not shown because too many files have changed in this diff Show More