Compare commits

...

21 Commits

Author SHA1 Message Date
lukaszraczylo bfd702a447 fix(jwk): keep parsed JWKS in local cache only (#134) (#136)
Under yaegi (Traefik's plugin runtime) json.Marshal exposes unexported
struct fields with an X-prefixed name. parsedJWKS{ keys map[string]
crypto.PublicKey } therefore round-tripped through Redis as
{"Xkeys":{"<kid>":{"N":<huge>,"E":65537}}} — *rsa.PublicKey.N is a
*big.Int that marshals to a JSON number hundreds of digits long. On
read, json.Unmarshal into interface{} parses numbers as float64, which
cannot represent that range:

  Failed to deserialize value for key .../discovery/v2.0/keys:parsed:
  json: cannot unmarshal number 2251513...
    into Go value of type float64

Auth still worked (the JWKCache rebuilt the keys in memory on every
miss) but the error log spammed every request.

Two structural problems were behind it:

* parsedJWKS holds crypto.PublicKey interface values that aren't
  meaningfully JSON-serializable. Even on compiled Go (where the
  unexported field marshals to {}), the post-roundtrip type assertion
  v.(*parsedJWKS) silently failed and the cache was useless.
* The same pattern applied to *JWKSet — the struct shape survived JSON
  but the type assertion still failed, defeating the cache for every
  call that went through Redis.

Both keys now use the new UniversalCache.SetLocal/GetLocal pair, which
skips the configured distributed backend entirely. JWK rotation is rare
and a per-replica HTTP fetch on cold cache is cheap, so cross-replica
coherence buys nothing for these entries.

Stale Redis entries written by previous versions are simply ignored —
the new code never reads under those keys, and Redis TTL retires them.

Includes regression coverage for the Azure round-trip, the
poisoned-stale-data scenario, and the SetLocal/GetLocal isolation
contract.

patch-release
2026-05-08 13:35:23 +01:00
lukaszraczylo 68c150eba4 fix(cache/redis): honor enableTLS for Redis backend (#133)
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the
config layer but silently dropped before reaching the connection pool, so
the plugin always dialed Redis in plaintext. This blocked TLS-only Redis
deployments such as AWS ElastiCache with in-transit encryption.

- Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and
  PoolConfig and forward them through universal_cache_singleton ->
  backends.Config -> PoolConfig.
- In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2
  minimum) with SNI defaulting to the host part of the configured
  Address when TLSServerName is empty, so ElastiCache cluster endpoints
  validate out of the box. Plain dial path now also propagates ctx.
- Add regression tests covering successful TLS negotiation with skip-
  verify, rejection of self-signed certs without skip-verify, rejection
  of plain TCP servers when EnableTLS=true, and unaffected plaintext
  behavior.
- Document maxRefreshTokenAgeSeconds (added in 1b6c861) and the implicit
  SSE / WebSocket auth bypass (added in 684a990) in README.md,
  docs/CONFIGURATION.md and docs/index.html.
- Add the missing redis.tlsSkipVerify row to docs/index.html and clarify
  the redis.enableTLS description.

patch-release
2026-05-07 12:24:13 +01:00
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00
lukaszraczylo 684a990f59 fix: reduce yaegi CPU footprint + require auth on SSE/WebSocket bypass
minor-release

Behaviour changes (potentially breaking for operators relying on the prior
unauthenticated SSE bypass):

* SSE (`Accept: text/event-stream`) and WebSocket upgrade requests now
  return 401 when no authenticated session is present. Previously the
  bypass forwarded unconditionally, which let any caller reach the
  backend by setting the right header. Excluded URLs are unchanged.
  Operators relying on unauthenticated SSE/WS access must move the path
  into ExcludedURLs.

Performance fixes (target: long-running dashboards like Grafana / ArgoCD
where many panels poll concurrently while the page stays open):

* Stop honouring isTestMode() for the singleton-token-cleanup interval
  under yaegi (the Traefik plugin runtime). In production the plugin was
  running a 20 Hz no-op cleanup ticker because runtime.Compiler ==
  "yaegi" tripped the test-mode branch.
* processAuthorizedRequest now resolves ID-token claims at most once per
  request via SessionData.GetIDTokenClaims (already cached on the
  session) and reuses them for both groups/roles extraction and
  header-template rendering. Previously every authenticated request
  parsed the JWT twice.
* Added extractGroupsAndRolesFromClaims to drive groups/roles off
  pre-parsed claims; extractGroupsAndRoles still works for tests.
* Removed the unconditional session.MarkDirty() in the header-templates
  branch. Templates only mutate request headers, not session state, so
  the prior MarkDirty was re-encrypting and rewriting all session
  cookies on every authenticated request that used header templates.

Other:

* Added isWebSocketUpgrade (RFC 6455 handshake detection — Connection:
  Upgrade + Upgrade: websocket, tolerant of multi-token Connection
  headers and case).
* Renamed applySSEUserHeaders -> applyBypassUserHeaders; it now returns
  bool so the dispatcher can reject unauthenticated SSE/WS with 401.
* Added tests for SSE and WS bypass covering both the auth-rejection
  path and the authenticated forward path.
2026-05-02 03:12:20 +01:00
lukaszraczylo 1b6c8616fd fix(refresh): coalesce refresh-token grants + bound goroutines + cache hot path (target v0.8.27) (#131)
* fix(refresh): wire RefreshCoordinator into the live refresh path

The RefreshCoordinator existed but was never instantiated. The actual
refresh path used only session.refreshMutex, which is per-SessionData
instance - and SessionData is pulled from a sync.Pool per request -
so concurrent requests sharing a refresh token had ZERO coordination.

Symptom: when access_token expired (e.g. 5min Zitadel default), every
in-flight request from a polling client (Grafana panels) entered the
refresh path simultaneously and POSTed the same refresh_token to the
IdP. With refresh-token rotation enabled (Zitadel/Authentik default),
only one grant succeeded; the rest got invalid_grant and each cleared
the entire session. Subsequent requests then thrashed in re-auth loops.

This commit:
- adds refreshCoordinator field on TraefikOidc
- instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig
- shuts it down in Close() under shutdownOnce
- routes refreshToken() through the coordinator via coordinatedTokenRefresh,
  which collapses concurrent grants to a single upstream call per
  refresh_token hash
- exports refreshCoordinatorSessionID for both internal hashing and the
  middleware-level wireup so dedup keys stay aligned

Behavioural notes:
- nil-coordinator fallback preserves existing tests that build TraefikOidc
  literals without going through the constructor
- followers receive the same TokenResponse/error as the leader, so no
  per-instance code paths change
- existing TestGetNewTokenWithRefreshToken_Concurrency still passes
  because it hits GetNewTokenWithRefreshToken directly, below the
  coordinator boundary

Tests:
- refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce
  to <=2 upstream calls; distinct tokens still run in parallel; nil
  coordinator falls back cleanly

* perf(cache): bound L1 backfill goroutines in HybridBackend

Get() and GetMany() previously spawned a goroutine per L2 hit to write
the value through to L1. Under sustained polling traffic (e.g. a Grafana
dashboard refreshing every 30s with N panels) this minted thousands of
goroutines, each running in Yaegi - directly contributing to the
~1000% CPU spike that pairs with the refresh-token herd.

Replace the per-hit goroutines with a single l1BackfillWorker fed by
l1BackfillBuffer, mirroring the existing asyncWriteBuffer/asyncWriteWorker
pattern for L2 writes. Buffer overflow drops the backfill (counted via
l1BackfillDrops) - a dropped backfill just means the next L2 hit for
that key re-queues it, which is safe.

Tests:
- TestHybridBackend_L1BackfillBounded: 1000 distinct L2 hits keep
  goroutine count within +20 of baseline (pre-fix it grew by ~1000)
- TestHybridBackend_L1BackfillFullDrops: drops are accounted for when
  the buffer is saturated and the worker is stopped

* feat(refresh): implement isRefreshTokenExpired heuristic

Replace the placeholder `return false` with a real check based on the
issued_at timestamp that SetRefreshToken already stamps into the session.
Gated by a new MaxRefreshTokenAgeSeconds config field (default 21600 =
6h, matching the existing comment). 0 disables the check.

This wires the previously-dead refreshTokenExpired branch in middleware.go,
which short-circuits AJAX requests with a 401 instead of letting them
hammer the IdP for a refresh token that's almost certainly stale - the
classic Grafana-after-long-pause failure mode.

Behaviour:
- maxRefreshTokenAge=0 disables the check (preserves prior behaviour)
- legacy sessions without issued_at still attempt one refresh; the IdP
  remains the source of truth on first try
- nil-receiver and nil-session guards keep test code that builds
  TraefikOidc literals safe

Tests:
- TestIsRefreshTokenExpired_DisabledWhenAgeZero
- TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp
- TestIsRefreshTokenExpired_WithinWindow
- TestIsRefreshTokenExpired_BeyondWindow
- TestIsRefreshTokenExpired_NilGuards

* perf(token): skip parseJWT on cache hit in VerifyToken

The token cache fast-return existed but ran AFTER parseJWT, so every
validation paid for base64 + JSON unmarshal even on a hit. Under bursty
traffic (e.g. 10+ concurrent panel requests on every Grafana dashboard
refresh, each calling validateStandardTokens which verifies BOTH the
access token and the ID token), this is two redundant parses per
request multiplied by the panel count.

Move the cache lookup ahead of parseJWT. On a hit the function returns
nil immediately. On a miss the original flow runs unchanged.

Also nil-guard t.tokenCache to keep partial-literal test instances safe
(matches the same pattern we already use for tokenBlacklist).

Tests:
- TestVerifyToken_CacheHitSkipsParse: cache pre-populated with claims
  for a token whose body would fail parseJWT - returns nil iff the
  fast-path bypasses the parse
- TestVerifyToken_CacheMissStillParses: a syntactically valid but
  unsigned token still errors past parseJWT on cache miss

* feat(refresh): cross-replica refresh-grant dedup via shared cache

The in-process RefreshCoordinator added in 9f96d8c already collapses
concurrent refresh-token grants on a single Traefik replica. With the
plugin's existing Redis (Dragonfly) cache infrastructure available, we
can extend that dedup across replicas: if pod A refreshes a token at
T+0 and pod B receives a request for the same session at T+1, pod B
should reuse pod A's result rather than POSTing the now-rotated refresh
token to the IdP.

Implementation:
- Add a refreshResultCache to UniversalCacheManager (memory-only when
  Redis is disabled, Redis-backed in production via the existing
  hybrid/Redis-only mode selection)
- Expose it through CacheManager.GetSharedRefreshResultCache and on the
  TraefikOidc struct as refreshResultCache (CacheInterface)
- Inside the closure passed to RefreshCoordinator.CoordinateRefresh,
  consult the cache first; on hit return immediately, on miss exchange
  with the IdP and populate the cache for peers
- 5s TTL: long enough for siblings to observe, short enough that a
  rotated refresh token cannot be re-supplied after the IdP has moved on
- Errors are intentionally NOT cached - peers must always be able to
  retry on their own

Pragmatic choice: optimistic cache rather than a hard distributed lock.
- A hard lock (SET NX + poll) doubles Redis RTT and risks dead-locks
  if a Traefik pod dies mid-grant.
- The user's BGP+Local externalTrafficPolicy already pins ingress for
  a session to one node in steady state, so cross-pod racing is rare.
- This optimistic path catches the rare failover case without adding
  failure modes.

Tests:
- TestCoordinatedTokenRefresh_CrossReplicaCacheHit: pre-populated cache
  short-circuits the upstream call entirely (0 IdP calls)
- TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache: leader stores
  a successful result for peers to find
- TestCoordinatedTokenRefresh_ErrorIsNotCached: invalid_grant must not
  poison the dedup cache - peers must retry independently
2026-04-30 18:52:39 +01:00
lukaszraczylo 4d28fa01ab perf(jwk,cache): cache parsed public keys + RLock token cache reads
Hot-path JWT verification rebuilt the public key on every call:
  jwk -> ToRSAPublicKey -> x509.MarshalPKIXPublicKey -> pem.Encode
  -> verifySignature -> pem.Decode -> x509.ParsePKIXPublicKey -> verify
Under yaegi this pinned a CPU when many concurrent dashboard panels
poll behind the middleware. The PEM round trip is pure waste.

* jwk.go: cache pre-parsed crypto.PublicKey per kid alongside the
  raw JWKSet (parallel cache entry, same 1h TTL, invalidates together).
* jwt.go: split verifySignatureWithKey from verifySignature; existing
  PEM-input entry point preserved for backchannel-logout callers.
* token_manager.go: VerifyJWTSignatureAndClaims now goes straight from
  jwks cache to verifySignatureWithKey, no PEM round trip and no
  per-request availableKids slice.
* universal_cache.go: token/JWK/session Get() takes RLock when the
  entry is unexpired, so concurrent token verifications no longer
  serialize on a single mutex. LRU semantics for general and metadata
  caches are unchanged (tests cover the strict-LRU contract there).
* mocks: MockJWKCache, EnhancedMockJWKCache, mockJWKCacheForLogout,
  staticJWKCache satisfy the extended interface.
2026-04-30 10:14:10 +01:00
lukaszraczylo 2d1b04c637 review fixes apr 2026 (#130)
* Multiple fixes

- refresh coordinator dedup + memory pressure wire
- middleware sse consolidation + timer leak + claim cache
- universal cache sync backfill + isDebug gate
- lazy background task race
- memory monitor stw cached + refresh() api

* fix(auth): suppress OIDC redirects on non-navigation requests

- [x] Add isNonNavigationRequest using Sec-Fetch-Mode and Accept headers
- [x] Add comprehensive TestIsNonNavigationRequest
- [x] Update ServeHTTP to 401 non-navigation and AJAX requests

Fixes #129

* feat(config): add custom CA and insecure skip verify for OIDC TLS

- [x] Add CACertPath, CACertPEM, InsecureSkipVerify to Config
- [x] Implement loadCACertPool for CA bundle loading
- [x] Update HTTPClientConfig with RootCAs and InsecureSkipVerify
- [x] Apply CA pool and skip verify to pooled HTTP clients
- [x] Enhance configKey to distinguish TLS configs
- [x] Add comprehensive ca_cert_test.go

Fixes #125

* feat(oidc): add custom CA certificate support for private OIDC providers

- [x] Add caCertPath, caCertPEM, insecureSkipVerify config options
- [x] Update traefik.yml with new OIDC client config fields
- [x] Add configuration schema descriptions for new options
- [x] Update README table and add Custom CA Certificates section

* Fix the documentation.

* test(redis): add oversized argument rejection test

- [x] Add TestRedisConn_RejectOversizedArgumentBytes
- [x] Import strings package

* Dependencies cleanup
2026-04-19 10:12:00 +01:00
lukaszraczylo ccbb98b9dd fix-issue-122 (#128) 2026-03-04 00:23:30 +00:00
Serhii Vasyliev 1362cc0dac Improve debug logging around callback URL matching (#126)
* Add debug logging around callback URL matching in ServeHTTP

The callback URL comparison at the core of OIDC flow had zero logging,
making it extremely difficult to diagnose redirect loop issues caused
by misconfigured callbackURL (e.g., full URL vs path-only).

Every other path comparison in ServeHTTP already logs debug info
(logout, backchannel, frontchannel, excluded URLs), but the callback
URL check was completely silent.

Added debug logs that show:
- The values being compared (request path vs configured callback)
- Whether the match succeeded or failed
- Configured redirURLPath during initialization

This would have immediately revealed the root cause of issue #1
where callbackURL was set as a full URL but compared against
req.URL.Path which only contains the path component.

Closes #3

* improve-callback-url-logging: Add init-time logging for callbackURL config
2026-02-23 10:36:37 +00:00
Yuval Bar-On 249dcad1b3 fix: prevent deadlock in SessionData.Clear method (#114)
Move mutex unlock before calling Save() to prevent potential deadlock
when Save() method needs to acquire the same mutex.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-16 15:02:33 +00:00
lukaszraczylo de4b4d7258 fix(cache): remove sync.Pool for Yaegi compatibility (#121)
- [x] Remove sync.Pool implementation that causes reflection panics
- [x] Replace pool-based NewRESPWriter with direct instantiation
- [x] Replace pool-based NewRESPReader with direct instantiation
- [x] Convert Release() methods to no-ops for API compatibility
- [x] Add documentation explaining sync.Pool removal for Yaegi
- [x] Remove "sync" import

Resolves #120
2026-01-19 17:52:31 +00:00
lukaszraczylo 9d52f1b018 feat(core): refactor linters config and improve code quality (#119)
- [x] Reorganize golangci-lint configuration with documented disable reasons
- [x] Simplify errcheck and revive linter rules with targeted exclusions
- [x] Pre-compile regex patterns in input_validation.go for performance
- [x] Fix type assertions in memory_shard.go and resp.go with safety checks
- [x] Replace string comparison with EqualFold for case-insensitive matching
- [x] Fix loop variable captures in jwk.go and logout.go
- [x] Change high goroutine log level from Info to Debug in autocleanup.go
- [x] Replace deprecated "cancelled" spelling with "canceled" throughout
- [x] Add nolint annotations for intentional unused parameters
- [x] Improve comment formatting for deprecated functions
- [x] Fix comment spelling: "marshalling" → "marshaling"
- [x] Refactor provider warnings formatting in internal/providers/warnings.go
- [x] Simplify metrics summary building in internal/recovery/metrics.go
- [x] Pre-allocate slice in error_recovery.go GetDegradedServices
- [x] Refactor context cancellation checks in redis.go
2026-01-15 10:40:49 +00:00
lukaszraczylo 57724918fe fix 116 (#118)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases

* docs: add security fix documentation for integer overflow protection

* test: fix goroutine tests to use mock OIDC servers

The TestContextAwareGoroutineManagement tests were making real HTTP
calls to hardcoded URLs like https://example.com, causing failures
in CI when those requests timeout or return HTTP errors.

Changes:
- Added createMockOIDCServer() helper function using httptest
- Updated GoroutineCleanupOnContextCancel to use mock server
- Updated NoGoroutineLeakOnMultipleInstances to use 3 mock servers
- Updated SingletonTasksAcrossInstances to use mock servers array

This prevents network calls and makes tests more reliable and faster.

Fixes test failures in GitHub Actions CI.
2026-01-08 22:50:46 +00:00
lukaszraczylo 775de2ada1 Fix cache serialisation (#117)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases
2026-01-08 22:06:19 +00:00
lukaszraczylo 7816e05c98 fix issue with logout url (#112)
* fix(logout): handle logout requests before OIDC initialization

- [x] Add debug logging to logout handler entry point
- [x] Move logout path check before OIDC initialization to enable logout when provider unavailable
- [x] Move excluded URL and SSE checks before initialization wait
- [x] Add debug logging for initialization wait to diagnose hanging requests
- [x] Add test for logout functionality without OIDC provider availability

* feat(logout): implement OIDC backchannel and front-channel logout

- [x] Add logout token validation and backchannel logout handler
- [x] Add front-channel logout handler with iframe support
- [x] Implement session invalidation cache for distributed deployments
- [x] Add comprehensive logout token claim verification (issuer, audience, events, iat, sid/sub)
- [x] Integrate session invalidation checks into authorization flow
- [x] Add configuration options for enabling backchannel/front-channel logout
- [x] Add extensive test coverage for logout flows and edge cases
- [x] Update documentation with logout configuration examples
- [x] Add middleware routing for logout endpoints
- [x] Extend cache manager with session invalidation cache support

Resolves #110

* fixup! feat(logout): implement OIDC backchannel and front-channel logout

* fixup! Merge branch 'main' into fix-issue-with-logout-url
2026-01-04 01:59:50 +00:00
Dominik Chilla 8bf7998150 Fix for Hashicorp Vault - accept opaque access tokens with dot-characters (#113) 2026-01-02 16:42:22 +00:00
muffn_ 22c4323fcb fix: set X-Forwarded-User header for SSE requests from existing session (#111)
Co-authored-by: muffin <MonsterMuffin@users.noreply.github.com>
2026-01-02 02:50:11 +00:00
lukaszraczylo 06b219d1f8 feat(dcr): Add Redis storage support for multi-replica deployments (#109)
- [x] Add file and Redis storage backends for DCR credentials
- [x] Implement storage abstraction with FileStore and RedisStore
- [x] Add factory function for automatic backend selection (auto/file/redis)
- [x] Integrate DCR credentials cache into UniversalCacheManager
- [x] Add comprehensive tests for storage backends and factory
- [x] Update configuration schema with storage backend options
- [x] Update documentation with multi-replica deployment guidance
- [x] Add Redis key prefix configuration for credential isolation
2025-12-31 12:52:39 +00:00
lukaszraczylo 413e4a1b7d LRU + cache conflicts prevention. (#104)
* LRU + cache conflicts prevention.

* Bugfix universalCache flooding ( issue #105 )

  1. Traefik cancels the context for old plugin instances
  2. Each plugin's Close() method is called
  3. The CacheInterfaceWrapper.Close() was calling cache.Close() on the shared singleton caches
  4. Each Close() triggered Clear() which logged "Cleared all items" at INFO level
2025-12-24 18:54:39 +00:00
lukaszraczylo 69e0d98c67 fixup! Add signing of the plugin on release. 2025-12-24 12:33:33 +00:00
lukaszraczylo 6d893df12b Add signing of the plugin on release. 2025-12-15 00:38:35 +00:00
142 changed files with 11691 additions and 6433 deletions
+2
View File
@@ -11,7 +11,9 @@ on:
workflow_dispatch:
permissions:
id-token: write
contents: write
packages: write
jobs:
release:
+1
View File
@@ -1,3 +1,4 @@
docker/
.claude/*.out
*.test
.leann/
+49 -32
View File
@@ -14,21 +14,22 @@ linters:
- gosec
- misspell
- noctx
- nolintlint
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- whitespace
disable:
- exhaustive
- funlen
- gocognit
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
- goprintffuncname # Disabled: naming convention is project-specific
- lll
- mnd
- testpackage
- whitespace # Disabled: style preference about newlines
- wsl
settings:
dupl:
@@ -47,29 +48,13 @@ linters:
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
ignore-tests: true
gocritic:
# Using default enabled checks in v2
enabled-checks:
- appendCombine
- boolExprSimplify
- builtinShadow
- commentedOutCode
- emptyFallthrough
- equalFold
- hexLiteral
- indexAlloc
- initClause
- methodExprCall
- nestingReduce
- rangeExprCopy
- rangeValCopy
- stringXbytes
- typeAssertChain
- typeUnparen
- unlabelStmt
- yodaStyleExpr
# Disable style-only checks that add noise
disabled-checks:
- ifElseChain # Style preference, switch not always clearer
- elseif # Style preference
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
@@ -106,23 +91,23 @@ linters:
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
- name: if-return
# - name: exported # Disabled: too noisy, not all exported functions need comments
# - name: if-return # Disabled: style preference
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
# - name: var-declaration # Disabled: explicit zero values can be clearer
# - name: package-comments # Disabled: handled by other tools
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
# - name: indent-error-flow # Disabled: style preference
- name: errorf
- name: empty-block
# - name: empty-block # Disabled: sometimes empty blocks are intentional
- name: superfluous-else
- name: unused-parameter
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
- name: unreachable-code
- name: redefines-builtin-id
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
unparam:
check-exported: false
staticcheck:
@@ -132,8 +117,15 @@ linters:
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1011 # Omit type from declaration - style preference
- -QF1012 # Use fmt.Fprintf - style preference
- -SA9003 # Empty branch - sometimes intentional for future work
- -ST1000 # Package comment format - not required for all packages
- -ST1003 # Package name format - allowed for test packages
- -ST1016 # Receiver name consistency - legacy code
- -ST1020 # Comment format for methods - style preference
- -ST1021 # Comment format for types - style preference
- -ST1023 # Omit type from declaration - style preference
exclusions:
generated: lax
rules:
@@ -144,18 +136,43 @@ linters:
- goconst
- gocyclo
- gosec
- govet
- ineffassign
- noctx
- prealloc
- unparam
- revive
- gocritic
path: _test\.go
- linters:
- dupl
- gocyclo
- govet
- noctx
- prealloc
- unparam
- revive
- gocritic
path: test.*\.go
- linters:
- gocritic
- unused
- errcheck
- revive
path: mocks.*\.go
- linters:
- errcheck
- revive
- gocritic
- govet
- unparam
path: internal/testutil/
- linters:
- govet
- unparam
- noctx
- prealloc
path: integration/
- linters:
- gosec
text: 'G404:'
+11
View File
@@ -47,3 +47,14 @@ release:
name_template: "v{{ .Version }}"
draft: false
prerelease: auto
signs:
- cmd: cosign
signature: "${artifact}.sigstore.json"
args:
- sign-blob
- "--bundle=${signature}"
- "${artifact}"
- "--yes"
artifacts: checksum
output: true
+47 -1610
View File
File diff suppressed because it is too large Load Diff
+227 -1952
View File
File diff suppressed because it is too large Load Diff
+49
View File
@@ -0,0 +1,49 @@
# Security Fix: Integer Overflow Protection in Cache Serialization
## Summary
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
## Vulnerability
**Locations**: `universal_cache.go` lines 789 and 811
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
## Fix Applied
1. **Added size limit constant**:
```go
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
```
2. **Size validation before allocation**:
- Validates entry size doesn't exceed limit
- Validates adding marker byte won't overflow
- Returns descriptive error messages
3. **Comprehensive test coverage**:
- Oversized byte slices (>64 MiB)
- Exact max size edge case
- Safe sizes (normal operation)
- Large JSON data structures
## Verification
✅ All tests pass with race detection
✅ No security issues (golangci-lint, gosec)
✅ 76.3% test coverage maintained
## Impact
- No breaking changes
- Negligible performance overhead
- Prevents potential buffer overflows
- Predictable memory usage
---
**Date**: January 8, 2026
**Severity**: High → Resolved
+1 -1
View File
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("Failed to set authenticated: %v", err)
}
session.SetEmail("user@company.com")
session.SetUserIdentifier("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
+60 -11
View File
@@ -4,8 +4,7 @@ import (
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
"time"
)
// validateRedirectCount checks if redirect limit is exceeded and handles the error
@@ -44,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
@@ -77,7 +76,12 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
csrfToken := uuid.NewString()
csrfToken, err := newUUIDv4()
if err != nil {
t.logger.Errorf("Failed to generate CSRF token: %v", err)
http.Error(rw, "Failed to generate CSRF token", http.StatusInternalServerError)
return
}
nonce, err := generateNonce()
if err != nil {
t.logger.Errorf("Failed to generate nonce: %v", err)
@@ -246,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetUserIdentifier(userIdentifier)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
@@ -286,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens to prevent replay attacks
session.SetCSRF("")
session.SetNonce("")
@@ -334,9 +338,54 @@ func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
strings.Contains(accept, "application/json")
}
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
// This is a heuristic check - actual implementation would depend on
// the specific provider and token metadata
return false // Placeholder implementation
// isNonNavigationRequest reports whether the request is a browser
// sub-resource (script, image, stylesheet, fetch, serviceWorker) rather than
// a top-level HTML navigation. Non-navigation requests MUST NOT trigger an
// OIDC redirect flow: several sub-resource loads happening in parallel would
// each call defaultInitiateAuthentication, each overwriting the session's
// CSRF/nonce, breaking the eventual callback (issue #129).
//
// Detection prefers Sec-Fetch-Mode, which all modern browsers send
// (Chrome/Edge/Firefox/Safari). For older or non-browser clients we fall
// back to Accept: if Accept is present and does not list text/html, treat
// it as a sub-resource. An empty/missing Accept is assumed to be navigation
// (safer to redirect than 401 on an ambiguous request).
func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
if mode := req.Header.Get("Sec-Fetch-Mode"); mode != "" {
return mode != "navigate"
}
accept := req.Header.Get("Accept")
if accept == "" || accept == "*/*" {
return false
}
return !strings.Contains(accept, "text/html")
}
// isRefreshTokenExpired checks whether the stored refresh token is likely
// past its useful lifetime, using the cookie-side issued_at timestamp set by
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
// MaxRefreshTokenAgeSeconds; 0 disables the check).
//
// The point of this check is to short-circuit the refresh path BEFORE the
// thundering herd hits the IdP for a token the provider has almost certainly
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
// style polling clients from looping on invalid_grant after a long pause.
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
if t == nil || session == nil {
return false
}
if t.maxRefreshTokenAge <= 0 {
return false
}
issuedAt := session.GetRefreshTokenIssuedAt()
if issuedAt.IsZero() {
// No timestamp recorded (legacy session pre-dating the issued_at
// field). Don't force a re-auth - attempt refresh once and let the
// IdP be the source of truth.
return false
}
return time.Since(issuedAt) > t.maxRefreshTokenAge
}
+88 -4
View File
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Pre-populate session with old data
_ = session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-access-token-with-many-characters")
session.SetRefreshToken("old-refresh-token-with-many-characters")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Verify old data is cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
// Verify new data is set
s.Equal(csrfToken, session.GetCSRF())
@@ -305,6 +305,90 @@ func (s *AuthFlowBehaviourSuite) TestIsAjaxRequest() {
}
}
// TestIsNonNavigationRequest verifies browser sub-resource detection used to
// suppress OIDC redirects on parallel static-asset loads (issue #129).
func (s *AuthFlowBehaviourSuite) TestIsNonNavigationRequest() {
testCases := []struct {
headers map[string]string
name string
expectNonNavigation bool
}{
{
name: "Sec-Fetch-Mode navigate",
headers: map[string]string{"Sec-Fetch-Mode": "navigate"},
expectNonNavigation: false,
},
{
name: "Sec-Fetch-Mode no-cors",
headers: map[string]string{"Sec-Fetch-Mode": "no-cors"},
expectNonNavigation: true,
},
{
name: "Sec-Fetch-Mode cors",
headers: map[string]string{"Sec-Fetch-Mode": "cors"},
expectNonNavigation: true,
},
{
name: "Sec-Fetch-Mode same-origin (fetch in page)",
headers: map[string]string{"Sec-Fetch-Mode": "same-origin"},
expectNonNavigation: true,
},
{
name: "Accept text/html (fallback)",
headers: map[string]string{"Accept": "text/html,application/xhtml+xml"},
expectNonNavigation: false,
},
{
name: "Accept image/png (fallback)",
headers: map[string]string{"Accept": "image/png,image/*;q=0.8"},
expectNonNavigation: true,
},
{
name: "Accept application/javascript (fallback)",
headers: map[string]string{"Accept": "application/javascript"},
expectNonNavigation: true,
},
{
name: "Accept */* treated as navigation",
headers: map[string]string{"Accept": "*/*"},
expectNonNavigation: false,
},
{
name: "No Accept header assumed navigation",
headers: map[string]string{},
expectNonNavigation: false,
},
{
name: "Sec-Fetch-Mode beats Accept (navigate wins)",
headers: map[string]string{
"Sec-Fetch-Mode": "navigate",
"Accept": "application/javascript",
},
expectNonNavigation: false,
},
{
name: "Sec-Fetch-Mode beats Accept (no-cors wins)",
headers: map[string]string{
"Sec-Fetch-Mode": "no-cors",
"Accept": "text/html",
},
expectNonNavigation: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
req := httptest.NewRequest(http.MethodGet, "/_static/asset.js", nil)
for key, value := range tc.headers {
req.Header.Set(key, value)
}
result := s.tOidc.isNonNavigationRequest(req)
s.Equal(tc.expectNonNavigation, result)
})
}
}
// TestHandleCallback_MissingState tests callback with missing state parameter
func (s *AuthFlowBehaviourSuite) TestHandleCallback_MissingState() {
sessionManager, err := NewSessionManager(
@@ -627,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
session, err := sessionManager.GetSession(req)
s.Require().NoError(err)
_ = session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
session.mainSession.Values["redirect_count"] = 3
@@ -636,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
// Session should be cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
s.Empty(session.GetIDToken())
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
+4 -3
View File
@@ -599,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
return globalTaskMemoryMonitor
}
// NewTaskMemoryMonitor creates a new memory monitor for task registry
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
//
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
return GetGlobalTaskMemoryMonitor(logger)
}
@@ -712,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
// Check for goroutine leaks (arbitrary threshold)
if stats.Goroutines > 100 {
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
}
// Check for heap growth without corresponding GC activity
+23 -14
View File
@@ -29,8 +29,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
pressure := monitor.GetMemoryPressure()
assert.Equal(t, MemoryPressureNone, pressure)
// Collect stats to populate lastStats
monitor.GetCurrentStats()
// Explicitly sample to populate lastStats; GetCurrentStats is now a
// cached read and no longer forces a runtime.ReadMemStats.
monitor.Refresh()
// Now should return a valid pressure level
pressure = monitor.GetMemoryPressure()
@@ -46,11 +47,13 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Start monitoring should not panic
// Start monitoring should not panic. Interval is clamped to the
// minimum (30s); we rely on Refresh() when we need a synchronous
// sample instead of waiting for a tick.
assert.NotPanics(t, func() {
ctx := context.Background()
monitor.StartMonitoring(ctx, 100*time.Millisecond)
time.Sleep(GetTestDuration(50 * time.Millisecond))
monitor.StartMonitoring(ctx, 0)
monitor.Refresh()
})
// Clean up
@@ -117,6 +120,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Refresh forces a synchronous sample; GetCurrentStats is a cached
// read, so we sample first to guarantee fresh data.
monitor.Refresh()
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
@@ -450,12 +456,12 @@ func TestMemoryMonitorIntegration(t *testing.T) {
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
defer monitor.StopMonitoring()
// Start monitoring
// Start monitoring. The interval is clamped to the minimum (30s) so
// the ticker won't fire during the test; drive the sample manually via
// Refresh() instead.
ctx := context.Background()
monitor.StartMonitoring(ctx, 50*time.Millisecond)
// Wait for at least one check
time.Sleep(GetTestDuration(150 * time.Millisecond))
monitor.StartMonitoring(ctx, 0)
monitor.Refresh()
// Get pressure (should be a valid pressure level)
pressure := monitor.GetMemoryPressure()
@@ -488,6 +494,7 @@ func TestMemoryStatsCollection(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
monitor.Refresh()
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
@@ -501,6 +508,7 @@ func TestMemoryStatsCollection(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
monitor.Refresh()
stats := monitor.GetCurrentStats()
// Should calculate and include pressure level
@@ -521,13 +529,14 @@ func TestMemoryStatsCollection(t *testing.T) {
// Allocate some memory
_ = make([]byte, 1024*1024) // 1MB
// Get stats before GC
beforeStats := monitor.GetCurrentStats()
// Get stats before GC (explicit Refresh so we have a fresh pre-GC
// snapshot to compare against, not the constructor baseline).
beforeStats := monitor.Refresh()
// Trigger GC
// Trigger GC (internally Refresh()es before and after)
monitor.TriggerGC()
// Get stats after GC
// Get stats after GC from cache (TriggerGC already refreshed it)
afterStats := monitor.GetCurrentStats()
// After GC should have different stats
+137
View File
@@ -0,0 +1,137 @@
package traefikoidc
import (
"encoding/pem"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
// testCertPEM returns a valid PEM-encoded certificate harvested from an
// httptest.NewTLSServer. Using httptest keeps the test free of any
// handwritten static cert that could expire.
func testCertPEM(t *testing.T) string {
t.Helper()
srv := httptest.NewTLSServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
t.Cleanup(srv.Close)
cert := srv.Certificate()
if cert == nil {
t.Fatal("httptest.NewTLSServer did not expose a certificate")
}
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
}
func TestLoadCACertPool_Empty(t *testing.T) {
cfg := &Config{}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool != nil {
t.Errorf("expected nil pool when no CA source configured, got %v", pool)
}
}
func TestLoadCACertPool_InlinePEM(t *testing.T) {
cfg := &Config{CACertPEM: testCertPEM(t)}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool for valid CACertPEM")
}
}
func TestLoadCACertPool_InlinePEM_Garbage(t *testing.T) {
cfg := &Config{CACertPEM: "not a pem"}
pool, err := cfg.loadCACertPool()
if err == nil {
t.Fatal("expected error for garbage CACertPEM, got nil")
}
if pool != nil {
t.Errorf("expected nil pool on error, got %v", pool)
}
if !strings.Contains(err.Error(), "caCertPEM") {
t.Errorf("error should name the failing field, got: %v", err)
}
}
func TestLoadCACertPool_FilePath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "ca.pem")
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
t.Fatalf("writing temp PEM: %v", err)
}
cfg := &Config{CACertPath: path}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool for valid CACertPath")
}
}
func TestLoadCACertPool_FilePath_Missing(t *testing.T) {
cfg := &Config{CACertPath: "/does/not/exist/ca.pem"}
pool, err := cfg.loadCACertPool()
if err == nil {
t.Fatal("expected error for missing CACertPath, got nil")
}
if pool != nil {
t.Errorf("expected nil pool on error, got %v", pool)
}
}
func TestLoadCACertPool_Combined(t *testing.T) {
// Both inline and file sources populated — certificates from both should
// be accepted into the same pool.
dir := t.TempDir()
path := filepath.Join(dir, "ca.pem")
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
t.Fatalf("writing temp PEM: %v", err)
}
cfg := &Config{CACertPath: path, CACertPEM: testCertPEM(t)}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool when both sources set")
}
}
func TestSharedTransportPool_ConfigKeyDistinguishesCAAndSkipVerify(t *testing.T) {
p := GetGlobalTransportPool()
cfgSystem := DefaultHTTPClientConfig()
cfgSkip := DefaultHTTPClientConfig()
cfgSkip.InsecureSkipVerify = true
cfgCustomCA := DefaultHTTPClientConfig()
pool, err := (&Config{CACertPEM: testCertPEM(t)}).loadCACertPool()
if err != nil {
t.Fatalf("loadCACertPool: %v", err)
}
cfgCustomCA.RootCAs = pool
keys := map[string]string{
"system": p.configKey(cfgSystem),
"skip": p.configKey(cfgSkip),
"customCA": p.configKey(cfgCustomCA),
}
seen := make(map[string]string, len(keys))
for name, key := range keys {
if dup, ok := seen[key]; ok {
t.Errorf("configKey collision: %s and %s share key %q", name, dup, key)
}
seen[key] = name
}
}
+32 -8
View File
@@ -20,8 +20,9 @@ var (
cacheManagerInitOnce sync.Once
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
// GetGlobalCacheManager returns a singleton CacheManager instance.
//
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
@@ -61,7 +62,7 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
}
// GetSharedTokenCache returns the shared token cache
@@ -93,7 +94,7 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
}
// GetSharedTokenTypeCache returns the shared token type cache
@@ -101,7 +102,23 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
// for backchannel and front-channel logout (IdP-initiated logout)
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
}
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
// by the refresh path to coalesce grants across Traefik replicas via Redis.
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
}
// Close gracefully shuts down all cache components
@@ -121,7 +138,8 @@ func CleanupGlobalCacheManager() error {
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
type CacheInterfaceWrapper struct {
cache *UniversalCache
cache *UniversalCache
managed bool // If true, cache is managed globally and Close() is a no-op
}
// Set stores a value
@@ -149,9 +167,15 @@ func (c *CacheInterfaceWrapper) Cleanup() {
c.cache.Cleanup()
}
// Close shuts down the cache
// Close shuts down the cache if it's not managed globally.
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
// when multiple plugin instances are closed during Traefik configuration reloads.
func (c *CacheInterfaceWrapper) Close() {
// Close the underlying cache to stop goroutines
if c.managed {
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
return
}
// Standalone cache - close it properly to stop cleanup goroutines
if c.cache != nil {
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
}
+153
View File
@@ -219,6 +219,159 @@ func TestCacheInterfaceWrapper_Close(t *testing.T) {
nilWrapper.Close()
}
// TestCacheInterfaceWrapper_ManagedClose_Regression tests that managed cache wrappers
// don't close the underlying cache when Close() is called. This is a regression test
// for issue #105 where multiple plugin instances closing shared caches caused log flooding.
func TestCacheInterfaceWrapper_ManagedClose_Regression(t *testing.T) {
cm := getTestCacheManager(t)
// Get a managed cache wrapper
cache := cm.GetSharedTokenBlacklist()
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
// Verify it's marked as managed
if !wrapper.managed {
t.Error("Expected shared cache wrapper to be marked as managed")
}
// Set some data before Close
cache.Set("test-key", "test-value", time.Hour)
// Close the wrapper (should be a no-op for managed caches)
wrapper.Close()
// Verify the cache is still operational after Close
value, found := cache.Get("test-key")
if !found {
t.Error("Expected cache to still work after Close() on managed wrapper")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
// Can still set new values
cache.Set("new-key", "new-value", time.Hour)
newValue, found := cache.Get("new-key")
if !found || newValue != "new-value" {
t.Error("Expected to be able to set new values after Close() on managed wrapper")
}
}
// TestCacheInterfaceWrapper_StandaloneClose tests that standalone cache wrappers
// properly close the underlying cache when Close() is called.
func TestCacheInterfaceWrapper_StandaloneClose(t *testing.T) {
// Create a standalone cache (not from the global cache manager)
standaloneCache := NewCache()
wrapper, ok := standaloneCache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
// Verify it's NOT marked as managed
if wrapper.managed {
t.Error("Expected standalone cache wrapper to NOT be marked as managed")
}
// Set some data
standaloneCache.Set("test-key", "test-value", time.Hour)
// Get baseline goroutine count
baselineGoroutines := runtime.NumGoroutine()
// Close the wrapper (should actually close the underlying cache)
wrapper.Close()
// Give cleanup goroutine time to stop
time.Sleep(100 * time.Millisecond)
// Goroutine count should decrease (cleanup routine stopped)
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > baselineGoroutines {
// This is acceptable - other tests might have started goroutines
t.Logf("Goroutine count: baseline=%d, final=%d", baselineGoroutines, finalGoroutines)
}
}
// TestCacheInterfaceWrapper_MultipleInstancesClose_Regression tests that multiple
// plugin instances can close their cache wrappers without affecting shared caches.
// This is a regression test for issue #105.
func TestCacheInterfaceWrapper_MultipleInstancesClose_Regression(t *testing.T) {
cm := getTestCacheManager(t)
// Simulate multiple plugin instances getting cache references
instances := make([]*CacheInterfaceWrapper, 5)
for i := 0; i < 5; i++ {
cache := cm.GetSharedTokenBlacklist()
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
instances[i] = wrapper
// Each instance might set some data
cache.Set(fmt.Sprintf("instance-%d-key", i), fmt.Sprintf("value-%d", i), time.Hour)
}
// Close all instances (simulating plugin shutdown/reload)
for _, wrapper := range instances {
wrapper.Close()
}
// The shared cache should still work after all instances closed their wrappers
newCache := cm.GetSharedTokenBlacklist()
// Data set by earlier instances should still be accessible
for i := 0; i < 5; i++ {
key := fmt.Sprintf("instance-%d-key", i)
value, found := newCache.Get(key)
if !found {
t.Errorf("Expected data from instance %d to still be accessible", i)
}
expectedValue := fmt.Sprintf("value-%d", i)
if value != expectedValue {
t.Errorf("Expected '%s', got '%v'", expectedValue, value)
}
}
// Should be able to add new data
newCache.Set("after-close-key", "after-close-value", time.Hour)
value, found := newCache.Get("after-close-key")
if !found || value != "after-close-value" {
t.Error("Expected to be able to use cache after all wrapper Close() calls")
}
}
// TestAllSharedCachesMarkedAsManaged verifies all shared cache getters
// return managed wrappers to prevent the log flooding issue.
func TestAllSharedCachesMarkedAsManaged(t *testing.T) {
cm := getTestCacheManager(t)
tests := []struct {
name string
cache CacheInterface
}{
{"TokenBlacklist", cm.GetSharedTokenBlacklist()},
{"IntrospectionCache", cm.GetSharedIntrospectionCache()},
{"TokenTypeCache", cm.GetSharedTokenTypeCache()},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wrapper, ok := tt.cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatalf("Expected CacheInterfaceWrapper for %s", tt.name)
}
if !wrapper.managed {
t.Errorf("%s cache wrapper should be marked as managed", tt.name)
}
})
}
}
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
+2 -2
View File
@@ -7,7 +7,7 @@ import (
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
@@ -47,7 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
+4 -4
View File
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
session.SetCSRF(csrfToken)
session.SetNonce("test-nonce")
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("old-refresh-token")
session.SetIDToken("old-id-token")
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Now perform selective clearing (as done in the fix)
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// Set initial session data
session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-token")
session.SetCSRF("existing-csrf")
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
// NEW BEHAVIOR: Selective clearing
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
+290
View File
@@ -0,0 +1,290 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
)
// DCRStorageBackend represents the type of storage backend for DCR credentials.
// Alias for internal package type for backward compatibility.
type DCRStorageBackend = dcrstorage.StorageBackend
const (
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
// DCRStorageBackendRedis uses Redis for distributed storage
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
)
// DCRCredentialsStore defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type DCRCredentialsStore interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
type loggerAdapter struct {
logger *Logger
}
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
type cacheAdapter struct {
cache *UniversalCache
}
func (c *cacheAdapter) Get(key string) (any, bool) {
return c.cache.Get(key)
}
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
return c.cache.Set(key, value, ttl)
}
func (c *cacheAdapter) Delete(key string) {
c.cache.Delete(key)
}
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
type fileStoreWrapper struct {
inner *dcrstorage.FileStore
}
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
func (w *fileStoreWrapper) basePath() string {
return w.inner.BasePath()
}
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
return w.inner.GetFilePath(providerURL)
}
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
type redisStoreWrapper struct {
inner *dcrstorage.RedisStore
}
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
type FileCredentialsStore = fileStoreWrapper
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
type RedisCredentialsStore = redisStoreWrapper
// NewFileCredentialsStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
return &fileStoreWrapper{inner: inner}
}
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
cacheAdapt := &cacheAdapter{cache: cache}
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
return &redisStoreWrapper{inner: inner}
}
// Helper functions to convert between main package and internal package types
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
if creds == nil {
return nil
}
return &dcrstorage.ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
if creds == nil {
return nil
}
return &ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
// This factory function handles backend selection logic:
// - "file": Use file-based storage (default for backward compatibility)
// - "redis": Use Redis exclusively (fails if Redis unavailable)
// - "auto": Use Redis if available, fallback to file
func NewDCRCredentialsStore(
config *DynamicClientRegistrationConfig,
cacheManager *CacheManager,
logger *Logger,
) (DCRCredentialsStore, error) {
if config == nil {
return nil, fmt.Errorf("DCR config is nil")
}
if logger == nil {
logger = GetSingletonNoOpLogger()
}
backend := config.StorageBackend
if backend == "" {
backend = string(DCRStorageBackendAuto) // Default to auto selection
}
switch DCRStorageBackend(backend) {
case DCRStorageBackendFile:
logger.Info("Using file-based storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
case DCRStorageBackendRedis:
cache := getDCRCache(cacheManager)
if cache == nil {
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
}
logger.Info("Using Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
case DCRStorageBackendAuto:
// Try Redis first, fallback to file
cache := getDCRCache(cacheManager)
if cache != nil && cache.backend != nil {
logger.Info("Auto-selected Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
}
logger.Info("Redis not available, using file storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
default:
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
}
}
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
if cacheManager == nil {
return nil
}
cacheManager.mu.RLock()
defer cacheManager.mu.RUnlock()
if cacheManager.manager == nil {
return nil
}
return cacheManager.manager.GetDCRCredentialsCache()
}
+663
View File
@@ -0,0 +1,663 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
t.Parallel()
// Create a temp directory for test files
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
// Save credentials
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
// Load credentials
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
// Verify fields
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
// Should not error
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
// TestFileCredentialsStore_MultiProvider tests multi-provider support
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
// Save credentials for both providers
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
// Load and verify each provider's credentials
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
// Delete one shouldn't affect the other
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
// Concurrent saves
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
// Concurrent loads
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
// Final verification
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
// TestFileCredentialsStore_InvalidInput tests error handling
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
// TestFileCredentialsStore_DefaultPath tests default path behavior
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore("", logger)
// Just verify we can create with empty path and it has a default
if store.basePath() == "" {
t.Error("Expected default base path")
}
}
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
t.Parallel()
// Create an in-memory cache for testing
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0, // No expiry
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
// TestRedisCredentialsStore_InvalidInput tests error handling
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
// TestDCRStorageFactory tests the factory function
func TestDCRStorageFactory(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
t.Run("nil config returns error", func(t *testing.T) {
_, err := NewDCRCredentialsStore(nil, nil, logger)
if err == nil {
t.Error("Expected error for nil config")
}
})
t.Run("file backend creates file store", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "file",
CredentialsFile: "/tmp/test-creds.json",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create file store: %v", err)
}
if store == nil {
t.Error("Expected store but got nil")
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore")
}
})
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "redis",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for redis backend without cache manager")
}
})
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "auto",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create auto store: %v", err)
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for auto without redis")
}
})
t.Run("unknown backend returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "unknown",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for unknown backend")
}
})
t.Run("empty backend defaults to auto", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create store with empty backend: %v", err)
}
// Should default to file (auto without redis)
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for empty backend")
}
})
}
// TestDynamicClientRegistrar_WithStore tests registrar with store
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil, // httpClient
logger,
config,
"https://auth.example.com",
store,
)
if registrar == nil {
t.Fatal("Expected registrar but got nil")
}
if registrar.store == nil {
t.Error("Expected store to be set")
}
// Test SetStore
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
registrar.SetStore(newStore)
if registrar.store != newStore {
t.Error("SetStore did not update the store")
}
}
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
providerURL := "https://auth.example.com"
ctx := context.Background()
// Pre-save credentials
testCreds := &ClientRegistrationResponse{
ClientID: "pre-saved-client",
ClientSecret: "pre-saved-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
}
if err := store.Save(ctx, providerURL, testCreds); err != nil {
t.Fatalf("Failed to pre-save credentials: %v", err)
}
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil,
logger,
config,
providerURL,
store,
)
// Test loading via the internal method
loaded, err := registrar.loadCredentialsFromStore(ctx)
if err != nil {
t.Fatalf("Failed to load from store: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != "pre-saved-client" {
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
}
}
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
// Write corrupted JSON
filePath := store.getFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
// Should return error for corrupted file
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(deepPath, logger)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+7 -4
View File
@@ -25,7 +25,10 @@ The **audience** (`aud`) claim in a JWT identifies the intended recipient of the
### Why Does This Matter?
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
Audience validation rejects access tokens whose `aud` claim does not match the
expected audience, blocking the trivial form of token confusion where a token
issued for API A is presented to API B. (Defence in depth — pair with
short-lived tokens, rotation, and per-API client credentials.)
---
@@ -137,8 +140,8 @@ http:
**Recommended:** `true` for production
**What it does:**
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
- When `true`: On audience mismatch, the middleware does **not** silently fall back to ID-token validation. It tries to refresh the access token first; if no refresh token is present (or refresh fails), the user is re-authenticated.
- When `false`: Logs warnings and falls back to ID-token validation (backward compatible).
**Example:**
```yaml
@@ -349,7 +352,7 @@ When opaque tokens are detected:
**Cache behavior:**
- Cache key: Token hash
- TTL: 5 minutes or token expiry (whichever is shorter)
- TTL: 5 minutes; if the token's `exp` is sooner, the cache entry expires at `exp` instead. Tokens without `exp` use the flat 5-minute TTL.
- Reduces introspection requests for frequently used tokens
---
+68 -7
View File
@@ -52,7 +52,7 @@ spec:
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
@@ -62,13 +62,40 @@ spec:
### TLS Termination at Load Balancer
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
the correct default behind any TLS-terminating load balancer (AWS ALB, Google
Cloud LB, Azure App Gateway) — `X-Forwarded-Proto` cannot be trusted (ALB may
overwrite it).
```yaml
forceHTTPS: true # Required for correct redirect URIs
```
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
dev). Otherwise leave it at default.
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
### Streaming Endpoints (SSE and WebSocket)
The middleware automatically bypasses the OIDC redirect for two request kinds
that browsers cannot follow a 302 on:
| Bypass | Triggered by |
|--------|--------------|
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
These requests do **not** require any explicit configuration — they are
handled implicitly. However, the bypass is **not** unauthenticated:
- A valid, encrypted session cookie is required. Requests without one are
rejected (the connection cannot proceed to the backend).
- The session cookie is sealed with `sessionEncryptionKey`, so the
`authenticated` flag cannot be forged.
- Validation is cookie-only — no JWK fetch / signature verification — so
streaming endpoints keep working when the OIDC provider is briefly
unavailable.
- The user identifier from the session is forwarded as `X-Forwarded-User`
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
For browser clients, the user must complete the normal OIDC flow on a
regular HTTP page first; the resulting session cookie is then reused on the
SSE / WebSocket connection.
---
@@ -113,6 +140,7 @@ strictAudienceValidation: true
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
@@ -384,10 +412,14 @@ scopes:
### Dynamic Client Registration (RFC 7591)
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
**Basic Configuration (Single Instance):**
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional
initialAccessToken: "your-token" # Optional, if provider requires it
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
@@ -400,6 +432,35 @@ dynamicClientRegistration:
- "refresh_token"
```
**Multi-Replica Deployment (Kubernetes):**
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
```yaml
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: "redis" # Share credentials via Redis
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
**Storage Backend Options:**
| Backend | Description | Use Case |
|---------|-------------|----------|
| `file` | Store credentials in local file | Single instance deployments |
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
### Multi-Replica Deployment
Without Redis, disable replay detection:
+95
View File
@@ -0,0 +1,95 @@
# Dynamic Client Registration (RFC 7591)
The middleware can register itself with an OIDC provider at startup instead of
using a pre-provisioned `clientID` / `clientSecret`. Useful for multi-tenant
deployments, self-service integrations, and ephemeral environments.
## How it works
1. Middleware reads `registration_endpoint` from `.well-known/openid-configuration`.
2. If `clientID` is empty, it `POST`s `clientMetadata` to the registration endpoint.
3. Returned `client_id` / `client_secret` are cached, optionally persisted.
4. Subsequent requests use the registered credentials.
For multi-replica deployments, set `storageBackend: redis` so all replicas
share one client and avoid registration races.
## Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-dcr
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://your-oidc-provider.com
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: redis # file | redis | auto
initialAccessToken: "" # optional, for protected endpoints
registrationEndpoint: "" # optional, override discovery
credentialsFile: /tmp/oidc-client-credentials.json
redisKeyPrefix: "dcr:creds:"
clientMetadata:
redirect_uris:
- https://app.example.com/oauth2/callback
client_name: My Application
application_type: web
grant_types: [authorization_code, refresh_token]
response_types: [code]
token_endpoint_auth_method: client_secret_basic
contacts: [admin@example.com]
```
## Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `enabled` | `false` | Enable DCR. |
| `persistCredentials` | `false` | Save returned credentials for reuse across restarts. |
| `storageBackend` | `auto` | `file`, `redis`, or `auto` (Redis if available, else file). |
| `credentialsFile` | `/tmp/oidc-client-credentials.json` | Path for file-backed storage. Mode `0600`. |
| `redisKeyPrefix` | (none — set explicitly) | Key prefix for Redis-backed storage. The code does not inject a default; if unset, keys have no prefix. `dcr:creds:` is a sensible convention. |
| `registrationEndpoint` | discovered | Override the discovered endpoint. |
| `initialAccessToken` | none | Bearer token for protected registration endpoints. |
| `clientMetadata.redirect_uris` | required | Callback URIs for the OAuth flow. |
| `clientMetadata.client_name` | none | Human-readable client name. |
| `clientMetadata.application_type` | `web` | `web` or `native`. |
| `clientMetadata.grant_types` | `[authorization_code, refresh_token]` | OAuth grant types. |
| `clientMetadata.response_types` | `[code]` | OAuth response types. |
| `clientMetadata.token_endpoint_auth_method` | `client_secret_basic` | `client_secret_basic`, `client_secret_post`, or `none`. |
| `clientMetadata.scope` | none | Space-separated scopes. |
| `clientMetadata.contacts` | none | Admin email addresses. |
| `clientMetadata.logo_uri` | none | Logo URL for consent screens. |
| `clientMetadata.client_uri` | none | Client homepage URL. |
| `clientMetadata.policy_uri` | none | Privacy policy URL. |
| `clientMetadata.tos_uri` | none | Terms of service URL. |
## Provider support
The middleware does not gate DCR by provider — if the provider exposes a
`registration_endpoint` in its discovery document (or you set
`registrationEndpoint` explicitly), DCR will attempt registration. The table
below is informational guidance based on each provider's published support.
| Provider | DCR | Notes |
|----------|-----|-------|
| Keycloak | Yes | Enable in realm settings. |
| Auth0 | Yes | Requires Management API token. |
| Okta | Yes | Enable Dynamic Client Registration in admin console. |
| Azure AD | Limited | Use App Registration API instead. |
| Google | No | Manual registration required. |
| AWS Cognito | No | Manual registration required. |
## Security notes
- Registration endpoints must be HTTPS (loopback excepted for local dev).
- Use `initialAccessToken` in production to gate registration.
- File-backed credentials use `0600`; protect the mount path.
- The plugin marks credentials invalid when within ~5 min of `client_secret_expires_at` but does **not** automatically re-register. If your provider sets a non-zero expiry, schedule manual rotation (delete the credentials file or Redis entry, restart) before that time.
+20 -99
View File
@@ -16,9 +16,8 @@ Guide for local development, testing, and contributing to the Traefik OIDC middl
## Prerequisites
- **Go 1.23+** for plugin compilation
- **Docker & Docker Compose** for local testing
- **OIDC Provider** credentials (Google, Azure, etc.)
- **Go 1.24+** (matches `go.mod`; CI runs Go 1.24.11)
- **OIDC Provider** credentials (Google, Azure, etc.) for any end-to-end test against a real provider
### Required Development Tools
@@ -40,110 +39,32 @@ go install golang.org/x/vuln/cmd/govulncheck@latest
## Local Development Setup
### Docker Compose Environment
The repository includes a Docker Compose setup for testing the plugin locally.
#### 1. Host Configuration
Add to `/etc/hosts`:
### Build and unit tests
```bash
127.0.0.1 hello.localhost
127.0.0.1 traefik.localhost
go mod tidy
go build ./...
go test ./... -short # fast loop, < 30 s
go test -race -timeout=15m ./...
```
#### 2. Plugin Configuration
### Sample plugin configurations
The plugin is loaded using Traefik's **local plugins mode**:
Working middleware/Traefik configs live in [`examples/`](../examples/):
- Plugin source: Parent directory (`../`)
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
- Configuration: `experimental.localPlugins` in `traefik.yml`
- `complete-traefik-config.yaml` — full middleware example
- `redis-config.yaml` — Redis cache configuration
#### 3. OIDC Provider Setup
To run the plugin against a real Traefik instance, drop the project on disk
and load it via `experimental.localPlugins` in your Traefik static config —
see the [README install section](../README.md#install).
Edit `docker/dynamic.yml` with your provider details:
### Integration tests
**Google:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
scopes:
- "openid"
- "email"
- "profile"
```
**Azure AD:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
scopes:
- "openid"
- "email"
- "profile"
```
#### 4. Start Environment
Integration tests live in `integration/`. Run them explicitly:
```bash
cd docker
docker-compose up -d
```
#### 5. Test Plugin
- **Protected App**: http://hello.localhost (redirects to OIDC)
- **Traefik Dashboard**: http://traefik.localhost:8080
### Development Workflow
1. **Edit plugin code** in the project root
2. **Build and test** (optional syntax check):
```bash
go mod tidy
go build .
go test ./...
```
3. **Restart Traefik** to reload plugin:
```bash
docker-compose restart traefik
```
4. **Test changes** at http://hello.localhost
### Debugging
**View plugin logs:**
```bash
docker-compose logs -f traefik | grep traefikoidc
```
**Check plugin loading:**
```bash
docker-compose logs traefik | grep -i plugin
```
**Verify plugin directory:**
```bash
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
go test ./integration/... -run Integration -v
```
---
@@ -299,7 +220,7 @@ The repository uses GitHub Actions for comprehensive validation with 20+ paralle
#### Testing (9 suites)
- Race Detector
- Coverage (75% threshold)
- Coverage (70% threshold, enforced in `pr.yaml`)
- Memory Leaks
- Integration Tests
- Regression Tests
@@ -323,13 +244,13 @@ Tests run in parallel for:
#### Performance & Build (3 checks)
- Benchmarks
- Multi-platform Build (linux/darwin x amd64/arm64)
- Go Version Compatibility (Go 1.23 & 1.24)
- Go Version Compatibility (currently Go 1.24.11 in CI)
### Quality Gates
All PRs must pass:
- All parallel checks
- 75% test coverage minimum
- 70% test coverage minimum
- Zero security vulnerabilities
- No race conditions
- No memory leaks
+5 -3
View File
@@ -23,10 +23,10 @@ Configuration reference for each supported OIDC provider.
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|----------|-------------|----------------|----------------|-----------|
| Google | Full | Yes | `accounts.google.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com`, `sts.windows.net` | Yes |
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
| Okta | Full | Yes | `*.okta.com` | Yes |
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
| Okta | Full | Yes | `*.okta.com`, `*.oktapreview.com`, `*.okta-emea.com` | Yes |
| Keycloak | Full | Yes | host containing `keycloak`, or `/realms/` in path (matches both `/auth/realms/` legacy and `/realms/` modern) | Yes |
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
| GitLab | Full | Yes | `gitlab.com` | Yes |
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
@@ -353,6 +353,8 @@ allowPrivateIPAddresses: true # Required for private IPs
- Roles: User Client Role mapper with "Add to ID token" enabled
- Groups: Group Membership mapper with "Add to ID token" enabled
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
---
## AWS Cognito
+14 -6
View File
@@ -109,11 +109,11 @@ redis:
| `writeTimeout` | int | `3` | Write timeout (seconds) |
| `enableTLS` | bool | `false` | Enable TLS for connections |
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
| `enableCircuitBreaker` | bool | `false` | Wrap the Redis backend with a circuit breaker. **Recommended `true` in production.** |
| `circuitBreakerThreshold` | int | `5` | Consecutive failures before the circuit opens (only when `enableCircuitBreaker: true`). |
| `circuitBreakerTimeout` | int | `60` | Seconds the circuit stays open before allowing a probe (only when `enableCircuitBreaker: true`). |
| `enableHealthCheck` | bool | `false` | Wrap the Redis backend with periodic health checks. **Recommended `true` in production.** |
| `healthCheckInterval` | int | `30` | Health check interval in seconds (only when `enableHealthCheck: true`). |
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
@@ -134,13 +134,21 @@ REDIS_READ_TIMEOUT=3
REDIS_WRITE_TIMEOUT=3
REDIS_ENABLE_TLS=false
REDIS_TLS_SKIP_VERIFY=false
REDIS_HYBRID_L1_SIZE=500
REDIS_HYBRID_L1_MEMORY_MB=10
```
> Resilience fields (`enableCircuitBreaker`, `enableHealthCheck`,
> `circuitBreakerThreshold`, `circuitBreakerTimeout`, `healthCheckInterval`)
> have no environment variable fallback — set them in plugin configuration.
Invalid `cacheMode` values are rejected at plugin startup.
---
## Cache Modes
### Memory Mode (Default without Redis)
### Memory Mode (used when Redis is disabled)
```yaml
redis:
+2 -2
View File
@@ -6,8 +6,8 @@ Comprehensive testing infrastructure for traefikoidc.
| Metric | Value |
|--------|-------|
| Test files | 99 |
| Lines of test code | ~65,500 |
| Test files | 110 |
| Lines of test code | ~72,000 |
| Code coverage | 71.0% |
| Race conditions | None (all pass with `-race`) |
+121 -2
View File
@@ -90,6 +90,7 @@
<a href="#configuration" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Configuration</a>
<a href="#deployment" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Deployment</a>
<a href="#security" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Security</a>
<a href="#logout" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Logout</a>
</div>
<div class="flex items-center space-x-4">
<button id="theme-toggle" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 p-2 min-w-[44px] min-h-[44px] flex items-center justify-center" aria-label="Toggle theme">
@@ -114,6 +115,7 @@
<a href="#configuration" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Configuration</a>
<a href="#deployment" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Deployment</a>
<a href="#security" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Security</a>
<a href="#logout" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Logout</a>
</div>
</div>
</nav>
@@ -193,7 +195,7 @@
</div>
<div>
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-1">Dynamic Registration</h3>
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration</p>
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration with Redis storage support for multi-replica deployments</p>
</div>
</div>
</div>
@@ -716,6 +718,11 @@ spec:
<td class="py-2 px-3">86400</td>
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
<td class="py-2 px-3">21600</td>
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
<td class="py-2 px-3">_oidc_raczylo_</td>
@@ -856,7 +863,54 @@ spec:
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable TLS for Redis connections</td>
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Dynamic Client Registration (RFC 7591)</h3>
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Automatically register your application with the OIDC provider. Supports Redis storage for multi-replica deployments:</p>
<div class="overflow-x-auto mb-4">
<table class="w-full text-sm">
<thead>
<tr class="border-b border-gray-200 dark:border-gray-700">
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Parameter</th>
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Default</th>
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Description</th>
</tr>
</thead>
<tbody class="text-gray-600 dark:text-gray-400">
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.enabled</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable dynamic client registration</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.persistCredentials</code></td>
<td class="py-2 px-3">true</td>
<td class="py-2 px-3">Persist registered credentials across restarts</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.storageBackend</code></td>
<td class="py-2 px-3">auto</td>
<td class="py-2 px-3">Storage backend: <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">file</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis</code>, or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">auto</code> (uses Redis if available)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.redisKeyPrefix</code></td>
<td class="py-2 px-3">dcr:creds:</td>
<td class="py-2 px-3">Redis key prefix for DCR credentials</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.clientMetadata.redirect_uris</code></td>
<td class="py-2 px-3">-</td>
<td class="py-2 px-3">Redirect URIs for the registered client (required)</td>
</tr>
</tbody>
</table>
@@ -1177,6 +1231,71 @@ spec:
</div>
</section>
<!-- IdP-Initiated Logout Section -->
<section id="logout" class="py-12 sm:py-16 md:py-20 bg-white dark:bg-gray-900 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
<div class="text-center mb-8 sm:mb-12">
<h2 class="text-2xl sm:text-3xl md:text-4xl font-bold text-gray-900 dark:text-gray-100 mb-3 sm:mb-4">IdP-Initiated Logout</h2>
<p class="text-base sm:text-lg text-gray-600 dark:text-gray-300 px-4">Support for OIDC Back-Channel and Front-Channel Logout specifications</p>
</div>
<div class="max-w-4xl mx-auto">
<div class="grid md:grid-cols-2 gap-6 mb-8">
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-server mr-2 text-blue-500"></i>
Back-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Signed JWT logout tokens</li>
<li>&#8226; Session ID (sid) based invalidation</li>
<li>&#8226; Subject (sub) based invalidation</li>
<li>&#8226; Works behind firewalls</li>
</ul>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-browser mr-2 text-purple-500"></i>
Front-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Iframe-based session termination</li>
<li>&#8226; Immediate cookie invalidation</li>
<li>&#8226; Simple GET request handling</li>
<li>&#8226; Issuer validation</li>
</ul>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Configuration Example</h3>
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
# ... other OIDC configuration ...
# Back-Channel Logout (server-to-server)
enableBackchannelLogout: true
backchannelLogoutURL: "/backchannel-logout"
# Front-Channel Logout (browser-based)
enableFrontchannelLogout: true
frontchannelLogoutURL: "/frontchannel-logout"</code></pre>
<p class="text-gray-600 dark:text-gray-400 text-sm mt-4">
Configure your IdP with the full URLs (e.g., <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">https://your-app.example.com/backchannel-logout</code>).
When a user logs out from the IdP, all their sessions across your applications will be invalidated.
</p>
</div>
</div>
</div>
</section>
<!-- Why Choose Section -->
<section class="py-12 sm:py-16 md:py-20 bg-gray-50 dark:bg-gray-800 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
+80 -11
View File
@@ -50,6 +50,7 @@ type DynamicClientRegistrar struct {
logger *Logger
config *DynamicClientRegistrationConfig
registrationResponse *ClientRegistrationResponse
store DCRCredentialsStore // Storage backend for credentials
providerURL string
mu sync.RWMutex
}
@@ -73,8 +74,37 @@ func NewDynamicClientRegistrar(
}
}
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
func NewDynamicClientRegistrarWithStore(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
store DCRCredentialsStore,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
store: store,
}
}
// SetStore sets the credentials store for the registrar
// This allows setting the store after creation when the cache manager is available
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
r.mu.Lock()
defer r.mu.Unlock()
r.store = store
}
// RegisterClient performs dynamic client registration with the OIDC provider
// It first attempts to load existing credentials from a file if persistence is enabled,
// It first attempts to load existing credentials from storage if persistence is enabled,
// then registers a new client if no valid credentials exist.
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
if r.config == nil || !r.config.Enabled {
@@ -83,10 +113,13 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
// Try to load existing credentials if persistence is enabled
if r.config.PersistCredentials {
if resp, err := r.loadCredentials(); err == nil && resp != nil {
resp, err := r.loadCredentialsFromStore(ctx)
if err != nil {
r.logger.Debugf("Failed to load credentials from store: %v", err)
} else if resp != nil {
// Check if credentials are still valid (not expired)
if r.areCredentialsValid(resp) {
r.logger.Info("Loaded existing client credentials from file")
r.logger.Info("Loaded existing client credentials from storage")
r.mu.Lock()
r.registrationResponse = resp
r.mu.Unlock()
@@ -179,7 +212,7 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
// Persist credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist client credentials: %v", err)
// Don't fail registration if persistence fails
}
@@ -315,7 +348,44 @@ func (r *DynamicClientRegistrar) credentialsFilePath() string {
return "/tmp/oidc-client-credentials.json"
}
// saveCredentials persists client credentials to a file
// loadCredentialsFromStore loads client credentials from the configured storage backend
// Falls back to legacy file-based loading if no store is configured
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
// Use store if available
if r.store != nil {
return r.store.Load(ctx, r.providerURL)
}
// Fallback to legacy file-based loading
return r.loadCredentials()
}
// saveCredentialsToStore persists client credentials to the configured storage backend
// Falls back to legacy file-based saving if no store is configured
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
// Use store if available
if r.store != nil {
return r.store.Save(ctx, r.providerURL, resp)
}
// Fallback to legacy file-based saving
return r.saveCredentials(resp)
}
// deleteCredentialsFromStore removes credentials from the configured storage backend
// Falls back to legacy file-based deletion if no store is configured
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
// Use store if available
if r.store != nil {
return r.store.Delete(ctx, r.providerURL)
}
// Fallback to legacy file-based deletion
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// saveCredentials persists client credentials to a file (legacy method)
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
@@ -333,7 +403,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
return nil
}
// loadCredentials loads client credentials from a file
// loadCredentials loads client credentials from a file (legacy method)
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
@@ -420,7 +490,7 @@ func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
@@ -527,11 +597,10 @@ func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) e
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials file if persistence is enabled
// Remove credentials from storage if persistence is enabled
if r.config.PersistCredentials {
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
r.logger.Errorf("Failed to remove credentials file: %v", err)
if err := r.deleteCredentialsFromStore(ctx); err != nil {
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
}
}
+27
View File
@@ -2,6 +2,8 @@ package traefikoidc
import (
"context"
"crypto"
"fmt"
"net/http"
"sync"
"sync/atomic"
@@ -40,6 +42,31 @@ func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, http
return m.JWKS, m.Err
}
func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
if jwks == nil {
return nil, fmt.Errorf("JWKS is nil")
}
for i := range jwks.Keys {
k := &jwks.Keys[i]
if k.Kid != kid {
continue
}
switch k.Kty {
case "RSA":
return k.ToRSAPublicKey()
case "EC":
return k.ToECDSAPublicKey()
default:
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
}
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
func (m *EnhancedMockJWKCache) Cleanup() {
atomic.AddInt32(&m.CleanupCalls, 1)
m.mu.Lock()
+1 -1
View File
@@ -954,7 +954,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
var degraded []string
degraded := make([]string, 0, len(gd.degradedServices))
for serviceName := range gd.degradedServices {
degraded = append(degraded, serviceName)
}
-1
View File
@@ -4,7 +4,6 @@ go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
-2
View File
@@ -12,8 +12,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
+16
View File
@@ -17,6 +17,21 @@ import (
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// newUUIDv4 returns an RFC 4122 v4 UUID string (e.g.
// "f47ac10b-58cc-4372-a567-0e02b2c3d479") backed by crypto/rand. Used for CSRF
// tokens and other opaque random identifiers — replaces github.com/google/uuid
// to keep the plugin stdlib-only on the production path.
func newUUIDv4() (string, error) {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return "", fmt.Errorf("could not generate UUID: %w", err)
}
b[6] = (b[6] & 0x0f) | 0x40 // version 4
b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
}
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
// The nonce is used to prevent replay attacks and associate client sessions with ID tokens.
// Returns:
@@ -336,6 +351,7 @@ func createStringMap(keys []string) map[string]struct{} {
// and redirects to the provider's logout endpoint or configured post-logout URI.
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing logout request")
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import (
"regexp"
"testing"
)
// TestNewUUIDv4 verifies the in-house UUID v4 generator produces RFC 4122
// compliant identifiers. Locks in the replacement for github.com/google/uuid
// — a regression here would weaken the CSRF token used in the OIDC flow.
func TestNewUUIDv4(t *testing.T) {
rfc4122v4 := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
const samples = 1000
seen := make(map[string]struct{}, samples)
for i := 0; i < samples; i++ {
got, err := newUUIDv4()
if err != nil {
t.Fatalf("newUUIDv4 failed: %v", err)
}
if !rfc4122v4.MatchString(got) {
t.Fatalf("UUID %q does not match RFC 4122 v4 format", got)
}
if _, dup := seen[got]; dup {
t.Fatalf("duplicate UUID emitted within %d samples: %q", samples, got)
}
seen[got] = struct{}{}
}
}
+13 -5
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
@@ -25,10 +26,16 @@ type HTTPClientConfig struct {
Timeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
// RootCAs is an optional certificate pool used for TLS verification.
// A nil pool means "use the system trust store" (default behavior).
RootCAs *x509.CertPool
// InsecureSkipVerify disables TLS certificate verification.
// ONLY set this for local development against self-signed certificates.
InsecureSkipVerify bool
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
}
// DefaultHTTPClientConfig returns the default configuration for general use
@@ -203,7 +210,8 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
PreferServerCipherSuites: true,
InsecureSkipVerify: false, // Always verify certificates
RootCAs: config.RootCAs,
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
},
ForceAttemptHTTP2: config.ForceHTTP2,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
+18 -3
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"sync"
@@ -103,7 +104,8 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
PreferServerCipherSuites: true,
InsecureSkipVerify: false,
RootCAs: config.RootCAs,
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
},
ForceAttemptHTTP2: config.ForceHTTP2,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
@@ -205,8 +207,21 @@ func (p *SharedTransportPool) performCleanup() {
// configKey generates a unique key for a config
func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
// Simple key based on main parameters
return string(rune(config.MaxConnsPerHost)) + string(rune(config.MaxIdleConnsPerHost))
// Pool transports by the parameters that change TLS or connection
// behavior. RootCAs and InsecureSkipVerify MUST be part of the key:
// otherwise a middleware configured with a custom CA would share a
// transport with one using the system store, silently bypassing its
// CA configuration.
skip := "0"
if config.InsecureSkipVerify {
skip = "1"
}
return fmt.Sprintf("%d|%d|%p|%s",
config.MaxConnsPerHost,
config.MaxIdleConnsPerHost,
config.RootCAs,
skip,
)
}
// Cleanup closes all transports and stops the cleanup goroutine
+14 -27
View File
@@ -10,6 +10,14 @@ import (
"unicode/utf8"
)
// Pre-compiled regex patterns for validation (const patterns should use MustCompile)
var (
emailRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
urlRegexPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
tokenRegexPattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
usernameRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
)
// InputValidator provides comprehensive input validation and sanitization
// to protect against common security vulnerabilities including SQL injection,
// XSS, path traversal, and other injection attacks. It validates and sanitizes
@@ -73,7 +81,7 @@ func DefaultInputValidationConfig() InputValidationConfig {
}
// NewInputValidator creates a new input validator with the specified configuration.
// It compiles all necessary regex patterns and initializes security pattern lists.
// It uses pre-compiled regex patterns and initializes security pattern lists.
//
// Parameters:
// - config: Validation configuration with size limits and mode settings.
@@ -81,29 +89,8 @@ func DefaultInputValidationConfig() InputValidationConfig {
//
// Returns:
// - A configured InputValidator instance.
// - An error if regex compilation fails.
// - An error (always nil, kept for API compatibility).
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
// Compile regex patterns
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if err != nil {
return nil, fmt.Errorf("failed to compile email regex: %w", err)
}
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
if err != nil {
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
}
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile token regex: %w", err)
}
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile username regex: %w", err)
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
@@ -112,10 +99,10 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
emailRegex: emailRegexPattern,
urlRegex: urlRegexPattern,
tokenRegex: tokenRegexPattern,
usernameRegex: usernameRegexPattern,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
+3
View File
@@ -24,6 +24,7 @@ type Config struct {
Type BackendType
RedisAddr string
RedisPassword string
TLSServerName string
PoolSize int
RedisDB int
CleanupInterval time.Duration
@@ -34,6 +35,8 @@ type Config struct {
EnableCircuitBreaker bool
EnableHealthCheck bool
EnableMetrics bool
EnableTLS bool
TLSSkipVerify bool
}
// DefaultConfig returns a default configuration for in-memory caching
+73 -26
View File
@@ -20,6 +20,7 @@ type HybridBackend struct {
ctx context.Context
syncWriteCacheTypes map[string]bool
asyncWriteBuffer chan *asyncWriteItem
l1BackfillBuffer chan *l1BackfillItem
cancel context.CancelFunc
wg sync.WaitGroup
l1Hits atomic.Int64
@@ -28,6 +29,7 @@ type HybridBackend struct {
l1Writes atomic.Int64
misses atomic.Int64
l2Hits atomic.Int64
l1BackfillDrops atomic.Int64
fallbackMode atomic.Bool
}
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
ttl time.Duration
}
// l1BackfillItem represents a deferred write of an L2-resolved value back into
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
type l1BackfillItem struct {
key string
value []byte
ttl time.Duration
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
h.wg.Add(1)
go h.asyncWriteWorker()
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
// L2 hits regardless of request rate.
h.wg.Add(1)
go h.l1BackfillWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
// Populate L1 cache with value from L2 (write-through on read).
// Hand off to the bounded backfill worker instead of spawning a goroutine
// per read - under burst that would mint thousands of goroutines.
h.queueL1Backfill(key, value, ttl)
return value, ttl, true, nil
}
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
// Close async write channel
close(h.asyncWriteBuffer)
close(h.l1BackfillBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
}
}
} else {
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
h.queueL1Backfill(key, value, ttl)
}
}
}
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
return nil
}
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
// Drops on full buffer to keep the read path constant-time; the next L2 hit
// for the same key simply re-queues it.
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
select {
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
default:
h.l1BackfillDrops.Add(1)
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key)
}
}
// l1BackfillWorker drains the backfill queue serially. Single worker is
// intentional - L1 writes are local and cheap, and serializing them keeps
// goroutine count bounded under any read rate.
func (h *HybridBackend) l1BackfillWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items best-effort then exit.
for len(h.l1BackfillBuffer) > 0 {
select {
case item := <-h.l1BackfillBuffer:
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.l1BackfillBuffer:
if !ok {
return
}
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key)
}
cancel()
}
}
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
+112
View File
@@ -0,0 +1,112 @@
//go:build !yaegi
package backends
import (
"context"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
primary := newMockBackend()
secondary := newMockBackend()
hybrid, err := NewHybridBackend(&HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 256,
})
require.NoError(t, err)
defer hybrid.Close()
ctx := context.Background()
const burst = 1000
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
// fresh L1 backfill enqueue.
for i := 0; i < burst; i++ {
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
}
baseline := runtime.NumGoroutine()
// Issue the burst as fast as possible; the backfill worker MUST be the
// only goroutine doing L1 writes. Allow brief slack for the test runtime
// scheduling but anything north of +20 means goroutine leakage.
peak := baseline
for i := 0; i < burst; i++ {
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
require.NoError(t, err)
require.True(t, exists)
if g := runtime.NumGoroutine(); g > peak {
peak = g
}
}
delta := peak - baseline
if delta > 20 {
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
delta, baseline, peak)
}
// L1 must eventually catch up via the worker. Worker drains serially so
// give it a generous window proportional to the burst size.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
var populated int
for i := 0; i < burst; i++ {
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
populated++
}
}
// Be lenient: drops are acceptable under buffer pressure, just want
// most of the keys to make it.
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
}
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
// buffer is saturated. Drops must be counted, never block, never spawn a
// goroutine.
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
primary := newMockBackend()
secondary := newMockBackend()
// Tiny buffer + slow primary writes via failSet so the worker stays
// blocked enough to overflow the buffer.
hybrid, err := NewHybridBackend(&HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 4,
})
require.NoError(t, err)
defer hybrid.Close()
// Stop the worker from draining: cancel the underlying context so the
// worker bails out, leaving us with a cold buffer and the queue method
// itself responsible for drop accounting.
hybrid.cancel()
// Wait for worker to exit so it can't drain.
time.Sleep(50 * time.Millisecond)
for i := 0; i < 50; i++ {
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
}
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
"expected some drops when buffer is saturated and worker is stopped")
}
+224 -197
View File
@@ -2,20 +2,27 @@
package backends
import (
"container/list"
"context"
"sync"
"sync/atomic"
"time"
)
// Default configuration values
const (
defaultShardCount = 256
defaultMaxSize = int64(10000)
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
defaultCleanupInterval = 5 * time.Minute
)
// memoryCacheItem represents an item in the memory cache
type memoryCacheItem struct {
expiresAt time.Time
createdAt time.Time
accessedAt time.Time
value interface{}
element *list.Element
element interface{} // *list.Element, using interface{} to avoid import cycle
key string
accessCount int64
size int64
@@ -29,56 +36,89 @@ func (item *memoryCacheItem) isExpired() bool {
return time.Now().After(item.expiresAt)
}
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
// The sharded design reduces lock contention by partitioning keys across multiple shards,
// each with its own lock.
type MemoryCacheBackend struct {
shards []*cacheShard
startTime time.Time
lastErrorTime time.Time
items map[string]*memoryCacheItem
lruList *list.List
cleanupDone chan bool
cleanupDone chan struct{}
cleanupTicker *time.Ticker
evictionPolicy string
lastError string
currentMemory int64
misses atomic.Int64
deletes atomic.Int64
evictions atomic.Int64
errors atomic.Int64
totalGetTime atomic.Int64
totalSetTime atomic.Int64
getCount atomic.Int64
setCount atomic.Int64
sets atomic.Int64
hits atomic.Int64
shardCount uint32
shardMask uint32
maxSize int64
currentSize int64
maxMemory int64
cleanupInterval time.Duration
mu sync.RWMutex
closed atomic.Bool
// Global stats (aggregated from shards)
hits atomic.Int64
misses atomic.Int64
sets atomic.Int64
deletes atomic.Int64
evictions atomic.Int64
errors atomic.Int64
// Latency tracking
totalGetTime atomic.Int64
totalSetTime atomic.Int64
getCount atomic.Int64
setCount atomic.Int64
// State
closed atomic.Bool
mu sync.RWMutex // For global operations like stats and error tracking
}
// NewMemoryCacheBackend creates a new memory cache backend
// NewMemoryCacheBackend creates a new sharded memory cache backend
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
if maxSize <= 0 {
maxSize = 10000 // Default to 10k items
maxSize = defaultMaxSize
}
if maxMemory <= 0 {
maxMemory = 100 * 1024 * 1024 // Default to 100MB
maxMemory = defaultMaxMemory
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute
cleanupInterval = defaultCleanupInterval
}
shardCount := uint32(defaultShardCount)
// For very small caches, reduce shard count to maintain sensible per-shard limits
// Ensure each shard can hold at least 2 items for proper LRU behavior
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
shardCount /= 2
}
if shardCount < 1 {
shardCount = 1
}
// Per-shard limits are soft hints; global limits are enforced
// Give shards 2x the average to allow for uneven distribution
shardMaxSize := (maxSize * 2) / int64(shardCount)
if shardMaxSize < 4 {
shardMaxSize = 4
}
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
if shardMaxMemory < 4096 {
shardMaxMemory = 4096 // Minimum 4KB per shard
}
m := &MemoryCacheBackend{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
shards: make([]*cacheShard, shardCount),
shardCount: shardCount,
shardMask: shardCount - 1, // For fast modulo with power-of-2
maxSize: maxSize,
maxMemory: maxMemory,
startTime: time.Now(),
cleanupInterval: cleanupInterval,
evictionPolicy: "lru",
cleanupDone: make(chan bool),
cleanupDone: make(chan struct{}),
}
// Initialize shards
for i := uint32(0); i < shardCount; i++ {
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
}
// Start cleanup goroutine
@@ -88,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.
return m
}
// getShard returns the shard for a given key
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
hash := fnv32(key)
return m.shards[hash&m.shardMask]
}
// cleanupLoop runs periodic cleanup of expired items
func (m *MemoryCacheBackend) cleanupLoop() {
for {
@@ -100,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() {
}
}
// cleanupExpired removes all expired items from the cache
// cleanupExpired removes all expired items from all shards
func (m *MemoryCacheBackend) cleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
var keysToDelete []string
for key, item := range m.items {
if item.isExpired() {
keysToDelete = append(keysToDelete, key)
}
if m.closed.Load() {
return
}
for _, key := range keysToDelete {
m.deleteItemLocked(key)
totalRemoved := 0
for _, shard := range m.shards {
totalRemoved += shard.cleanup()
}
if totalRemoved > 0 {
m.evictions.Add(int64(totalRemoved))
}
}
@@ -130,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{},
m.getCount.Add(1)
}()
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
shard := m.getShard(key)
value, exists, expired := shard.get(key)
if expired {
// Clean up expired item
shard.delete(key)
m.misses.Add(1)
return nil, ErrCacheMiss
}
if !exists {
m.misses.Add(1)
return nil, ErrCacheMiss
}
if item.isExpired() {
m.mu.Lock()
m.deleteItemLocked(key)
m.mu.Unlock()
m.misses.Add(1)
return nil, ErrCacheMiss
}
// Update access time and count
m.mu.Lock()
item.accessedAt = time.Now()
item.accessCount++
// Move to front of LRU list
if m.evictionPolicy == "lru" && item.element != nil {
m.lruList.MoveToFront(item.element)
}
m.mu.Unlock()
m.hits.Add(1)
return item.value, nil
return value, nil
}
// Set stores a value in the cache with optional TTL
@@ -174,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac
m.setCount.Add(1)
}()
// Calculate item size (simplified estimation)
// Calculate item size
itemSize := int64(len(key)) + estimateValueSize(value)
m.mu.Lock()
defer m.mu.Unlock()
// Enforce global limits before adding new item
m.enforceGlobalLimits(itemSize)
// Check if we need to evict items
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
m.evictLocked()
}
// Check if key exists
if oldItem, exists := m.items[key]; exists {
m.currentMemory -= oldItem.size
if oldItem.element != nil {
m.lruList.Remove(oldItem.element)
}
} else {
m.currentSize++
}
now := time.Now()
var expiresAt time.Time
if ttl > 0 {
expiresAt = now.Add(ttl)
expiresAt = time.Now().Add(ttl)
}
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: itemSize,
}
shard := m.getShard(key)
shard.set(key, value, expiresAt, itemSize)
// Add to LRU list
if m.evictionPolicy == "lru" {
item.element = m.lruList.PushFront(item)
}
m.items[key] = item
m.currentMemory += itemSize
m.sets.Add(1)
return nil
}
// enforceGlobalLimits ensures global size and memory limits are respected
// by evicting from shards when necessary
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
// Check and enforce size limit
for {
totalSize, totalMemory := m.getGlobalStats()
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
if !needsSizeEviction && !needsMemoryEviction {
break
}
// Find the shard with the most items and evict from it
evicted := m.evictFromLargestShard()
if !evicted {
break // No more items to evict
}
m.evictions.Add(1)
}
}
// getGlobalStats returns the total size and memory usage across all shards
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
for _, shard := range m.shards {
size, memory := shard.stats()
totalSize += size
totalMemory += memory
}
return
}
// evictFromLargestShard evicts the globally oldest item across all shards
// This provides true LRU behavior even with sharding
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
var oldestShard *cacheShard
var oldestTime time.Time
for _, shard := range m.shards {
accessTime := shard.getOldestAccessTime()
// Skip empty shards
if accessTime.IsZero() {
continue
}
// Find the shard with the oldest (earliest) access time
if oldestShard == nil || accessTime.Before(oldestTime) {
oldestTime = accessTime
oldestShard = shard
}
}
if oldestShard == nil {
return false
}
return oldestShard.evictOne()
}
// Delete removes a key from the cache
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.items[key]; !exists {
return nil
shard := m.getShard(key)
if shard.delete(key) {
m.deletes.Add(1)
}
m.deleteItemLocked(key)
m.deletes.Add(1)
return nil
}
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
if item, exists := m.items[key]; exists {
m.currentMemory -= item.size
m.currentSize--
if item.element != nil {
m.lruList.Remove(item.element)
}
delete(m.items, key)
}
}
// evictLocked evicts items based on the eviction policy (must be called with lock held)
func (m *MemoryCacheBackend) evictLocked() {
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
// Evict least recently used item
element := m.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
m.deleteItemLocked(item.key)
m.evictions.Add(1)
}
}
}
// Exists checks if a key exists in the cache
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
if m.closed.Load() {
return false, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
return false, nil
}
return !item.isExpired(), nil
shard := m.getShard(key)
return shard.exists(key), nil
}
// Clear removes all items from the cache
@@ -289,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryCacheItem)
m.lruList = list.New()
m.currentSize = 0
m.currentMemory = 0
for _, shard := range m.shards {
shard.clear()
}
return nil
}
@@ -306,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string
return nil, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var keys []string
for key, item := range m.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
var allKeys []string
for _, shard := range m.shards {
keys := shard.keys(pattern)
allKeys = append(allKeys, keys...)
}
return keys, nil
return allKeys, nil
}
// Size returns the number of items in the cache
// Size returns the total number of items in the cache
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var total int64
for _, shard := range m.shards {
size, _ := shard.stats()
total += size
}
return m.currentSize, nil
return total, nil
}
// TTL returns the remaining time-to-live for a key
@@ -337,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration
return 0, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists || item.isExpired() {
shard := m.getShard(key)
ttl, exists := shard.ttl(key)
if !exists {
return 0, ErrCacheMiss
}
if item.expiresAt.IsZero() {
return 0, nil // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, nil
}
return remaining, nil
return ttl, nil
}
// Expire updates the TTL for an existing key
@@ -363,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists || item.isExpired() {
shard := m.getShard(key)
if !shard.expire(key, ttl) {
return ErrCacheMiss
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return nil
}
@@ -386,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
return nil, ErrBackendUnavailable
}
// Aggregate stats from all shards
var totalSize, totalMemory int64
for _, shard := range m.shards {
size, memory := shard.stats()
totalSize += size
totalMemory += memory
}
m.mu.RLock()
lastError := m.lastError
lastErrorTime := m.lastErrorTime
@@ -409,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
Deletes: m.deletes.Load(),
Errors: m.errors.Load(),
Evictions: m.evictions.Load(),
CurrentSize: m.currentSize,
CurrentSize: totalSize,
MaxSize: m.maxSize,
MemoryUsage: m.currentMemory,
MemoryUsage: totalMemory,
AverageGetLatency: avgGetLatency,
AverageSetLatency: avgSetLatency,
LastError: lastError,
@@ -438,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error {
m.cleanupTicker.Stop()
close(m.cleanupDone)
m.mu.Lock()
m.items = nil
m.lruList = nil
m.mu.Unlock()
// Clear all shards
for _, shard := range m.shards {
shard.clear()
}
return nil
}
@@ -474,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
}
}
// GetShardCount returns the number of shards (for testing/monitoring)
func (m *MemoryCacheBackend) GetShardCount() uint32 {
return m.shardCount
}
// GetShardStats returns per-shard statistics (for monitoring)
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
stats := make([]map[string]int64, m.shardCount)
for i, shard := range m.shards {
size, memory := shard.stats()
stats[i] = map[string]int64{
"size": size,
"memory": memory,
}
}
return stats
}
// Helper functions
// estimateValueSize estimates the size of a value in bytes
func estimateValueSize(value interface{}) int64 {
// This is a simplified estimation
// In production, you might want to use a more accurate method
switch v := value.(type) {
case string:
return int64(len(v))
@@ -502,7 +526,10 @@ func matchPattern(pattern, key string) bool {
if pattern == "*" {
return true
}
// Simplified pattern matching - in production, use a proper glob library
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
// Simplified pattern matching
if len(pattern) > 0 && pattern[0] == '*' {
suffix := pattern[1:]
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
}
return key == pattern
}
+294
View File
@@ -0,0 +1,294 @@
package backends
import (
"container/list"
"sync"
"time"
)
// cacheShard represents a single shard of the sharded cache
// Each shard has its own lock for reduced contention
type cacheShard struct {
items map[string]*memoryCacheItem
lruList *list.List
mu sync.RWMutex
maxSize int64
maxMemory int64
size int64
memoryUsed int64
}
// newCacheShard creates a new cache shard
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
return &cacheShard{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
maxSize: maxSize,
maxMemory: maxMemory,
}
}
// get retrieves a value from this shard
// Returns: value, exists, expired
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists {
return nil, false, false
}
if item.isExpired() {
return nil, true, true // exists but expired
}
// Update access time and LRU position under write lock
s.mu.Lock()
// Re-check item exists (could have been deleted)
item, exists = s.items[key]
if exists && !item.isExpired() {
item.accessedAt = time.Now()
item.accessCount++
if elem, ok := item.element.(*list.Element); ok && elem != nil {
s.lruList.MoveToFront(elem)
}
}
s.mu.Unlock()
if !exists || item.isExpired() {
return nil, false, false
}
return item.value, true, false
}
// set stores a value in this shard
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
s.mu.Lock()
defer s.mu.Unlock()
// Check if we need to evict items
if s.maxSize > 0 && s.size >= s.maxSize {
s.evictLRULocked()
}
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
s.evictLRULocked()
}
// Remove old item if exists
if oldItem, exists := s.items[key]; exists {
s.memoryUsed -= oldItem.size
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
s.lruList.Remove(elem)
}
s.size--
}
now := time.Now()
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: size,
}
item.element = s.lruList.PushFront(item)
s.items[key] = item
s.size++
s.memoryUsed += size
}
// delete removes a key from this shard
// Returns true if the key was deleted
func (s *cacheShard) delete(key string) bool {
s.mu.Lock()
defer s.mu.Unlock()
item, exists := s.items[key]
if !exists {
return false
}
s.deleteItemLocked(item)
return true
}
// exists checks if a key exists (and is not expired)
func (s *cacheShard) exists(key string) bool {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists {
return false
}
return !item.isExpired()
}
// ttl returns the remaining TTL for a key
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists || item.isExpired() {
return 0, false
}
if item.expiresAt.IsZero() {
return 0, true // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, false
}
return remaining, true
}
// expire updates the TTL for an existing key
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
s.mu.Lock()
defer s.mu.Unlock()
item, exists := s.items[key]
if !exists || item.isExpired() {
return false
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return true
}
// keys returns all non-expired keys matching the pattern
func (s *cacheShard) keys(pattern string) []string {
s.mu.RLock()
defer s.mu.RUnlock()
var keys []string
for key, item := range s.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
}
return keys
}
// clear removes all items from this shard
func (s *cacheShard) clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.items = make(map[string]*memoryCacheItem)
s.lruList.Init()
s.size = 0
s.memoryUsed = 0
}
// cleanup removes expired items
// Returns the number of items removed
func (s *cacheShard) cleanup() int {
s.mu.Lock()
defer s.mu.Unlock()
var toRemove []*memoryCacheItem
for _, item := range s.items {
if item.isExpired() {
toRemove = append(toRemove, item)
}
}
for _, item := range toRemove {
s.deleteItemLocked(item)
}
return len(toRemove)
}
// stats returns statistics for this shard
func (s *cacheShard) stats() (size, memory int64) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.size, s.memoryUsed
}
// deleteItemLocked removes an item (must be called with lock held)
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
if elem, ok := item.element.(*list.Element); ok && elem != nil {
s.lruList.Remove(elem)
}
delete(s.items, item.key)
s.size--
s.memoryUsed -= item.size
}
// evictLRULocked evicts the least recently used item (must be called with lock held)
func (s *cacheShard) evictLRULocked() bool {
if s.lruList.Len() == 0 {
return false
}
element := s.lruList.Back()
if element != nil {
item, ok := element.Value.(*memoryCacheItem)
if ok {
s.deleteItemLocked(item)
return true
}
}
return false
}
// evictOne evicts one item from this shard (for global limit enforcement)
func (s *cacheShard) evictOne() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.evictLRULocked()
}
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
// Returns zero time if shard is empty
func (s *cacheShard) getOldestAccessTime() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
if s.lruList.Len() == 0 {
return time.Time{}
}
element := s.lruList.Back()
if element != nil {
item, ok := element.Value.(*memoryCacheItem)
if ok {
return item.accessedAt
}
}
return time.Time{}
}
// fnv32 computes FNV-1a hash of a string
// This is a fast, well-distributed hash function
func fnv32(key string) uint32 {
const (
offset32 = uint32(2166136261)
prime32 = uint32(16777619)
)
hash := offset32
for i := 0; i < len(key); i++ {
hash ^= uint32(key[i])
hash *= prime32
}
return hash
}
+283
View File
@@ -0,0 +1,283 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
func TestShardedCache_ShardDistribution(t *testing.T) {
t.Parallel()
// Create a cache with large enough size to have multiple shards
config := DefaultConfig()
config.MaxSize = 10000
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add many items to see distribution
numItems := 1000
for i := 0; i < numItems; i++ {
key := fmt.Sprintf("dist-key-%d", i)
value := []byte(fmt.Sprintf("dist-value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Check that items are distributed across multiple shards
shardStats := backend.MemoryCacheBackend.GetShardStats()
nonEmptyShards := 0
for _, stat := range shardStats {
if stat["size"] > 0 {
nonEmptyShards++
}
}
// With good hash distribution, we should have items in multiple shards
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
}
// TestShardedCache_ShardCount tests that shard count adapts to cache size
func TestShardedCache_ShardCount(t *testing.T) {
t.Parallel()
tests := []struct {
maxSize int
expectLowShards bool
}{
{5, true}, // Very small cache should have fewer shards
{100, true}, // Small cache should have fewer shards
{10000, false}, // Large cache should have default shards
}
for _, tt := range tests {
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
config := DefaultConfig()
config.MaxSize = tt.maxSize
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
shardCount := backend.MemoryCacheBackend.GetShardCount()
if tt.expectLowShards {
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
} else {
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
}
})
}
}
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "concurrent-same-key"
initialValue := []byte("initial-value")
err = backend.Set(ctx, key, initialValue, time.Minute)
require.NoError(t, err)
var wg sync.WaitGroup
goroutines := 50
iterations := 100
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Mix of reads and writes
if j%3 == 0 {
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
err := backend.Set(ctx, key, newValue, time.Minute)
assert.NoError(t, err)
} else {
_, _, _, err := backend.Get(ctx, key)
assert.NoError(t, err)
}
}
}(i)
}
wg.Wait()
// Key should still exist
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
t.Parallel()
// Create a small cache to force eviction
config := DefaultConfig()
config.MaxSize = 10
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
// Small delay to ensure different access times
time.Sleep(time.Millisecond)
}
// Access some items to make them recently used
for i := 5; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
_, _, _, err := backend.Get(ctx, key)
require.NoError(t, err)
}
// Add more items to trigger eviction
for i := 10; i < 15; i++ {
key := fmt.Sprintf("global-lru-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Recently accessed items (5-9) should still exist
for i := 5; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item %d should exist", i)
}
// Check eviction stats
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
func TestShardedCache_StatsAggregation(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 10000
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items to multiple shards
numItems := 100
for i := 0; i < numItems; i++ {
key := fmt.Sprintf("stats-key-%d", i)
value := []byte(fmt.Sprintf("stats-value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Read some items
for i := 0; i < numItems/2; i++ {
key := fmt.Sprintf("stats-key-%d", i)
backend.Get(ctx, key)
}
// Read non-existent items
for i := 0; i < 10; i++ {
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
}
stats := backend.GetStats()
// Verify stats
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
// Verify hit rate
hitRate := stats["hit_rate"].(float64)
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
}
// BenchmarkShardedCache_Parallel benchmarks parallel access
func BenchmarkShardedCache_Parallel(b *testing.B) {
config := DefaultConfig()
config.MaxSize = 100000
config.MaxMemoryBytes = 100 * 1024 * 1024
backend, _ := NewMemoryBackend(config)
defer backend.Close()
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 10000; i++ {
key := fmt.Sprintf("bench-key-%d", i)
value := []byte(fmt.Sprintf("bench-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("bench-key-%d", i%10000)
backend.Get(ctx, key)
i++
}
})
}
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
func BenchmarkShardedCache_MixedOps(b *testing.B) {
config := DefaultConfig()
config.MaxSize = 100000
config.MaxMemoryBytes = 100 * 1024 * 1024
backend, _ := NewMemoryBackend(config)
defer backend.Close()
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("mixed-key-%d", i%1000)
if i%3 == 0 {
value := []byte(fmt.Sprintf("mixed-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
} else {
backend.Get(ctx, key)
}
i++
}
})
}
+20 -30
View File
@@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
return nil, 0, false, err
}
// Get the item directly to check TTL
m.MemoryCacheBackend.mu.RLock()
item, exists := m.MemoryCacheBackend.items[key]
m.MemoryCacheBackend.mu.RUnlock()
if !exists {
return nil, 0, false, nil
}
var ttl time.Duration
if !item.expiresAt.IsZero() {
ttl = time.Until(item.expiresAt)
if ttl < 0 {
ttl = 0
}
// Get TTL using the TTL method
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
if ttlErr != nil {
// If we can't get TTL, still return the value with 0 TTL
ttl = 0
}
// Convert interface{} to []byte
@@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
if bytes, ok := val.([]byte); ok {
valueBytes = bytes
} else {
// If it's not already []byte, we might need to handle other types
// For now, we'll just return an error
// If it's not already []byte, return an error
return nil, 0, false, ErrInvalidValue
}
}
@@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} {
}
return map[string]interface{}{
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
"shard_count": m.MemoryCacheBackend.GetShardCount(),
}
}
+112 -13
View File
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
TLSServerName: config.TLSServerName,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 2 * time.Second,
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
EnableTLS: config.EnableTLS,
TLSSkipVerify: config.TLSSkipVerify,
}
pool, err := NewConnectionPool(poolConfig)
@@ -345,7 +348,7 @@ func (r *RedisBackend) prefixKey(key string) string {
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
// It checks context cancellation at multiple points to ensure fast abort when the
// caller's context is cancelled (e.g., due to request timeout).
// caller's context is canceled (e.g., due to request timeout).
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
@@ -377,7 +380,7 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
err = operation(conn)
r.pool.Put(conn)
// Check context after operation - if cancelled, don't bother retrying
// Check context after operation - if canceled, don't bother retrying
if ctx.Err() != nil {
return ctx.Err()
}
@@ -431,39 +434,135 @@ func isRetryableError(err error) bool {
return false
}
// SetMany stores multiple values in Redis (batch operation)
// SetMany stores multiple values in Redis using pipelining for efficiency
// This reduces N round-trips to a single round-trip
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
// For simplicity, execute sequentially (can be optimized with pipelining later)
for key, value := range items {
if err := r.Set(ctx, key, value, ttl); err != nil {
return err
if len(items) == 0 {
return nil
}
// For single items, use regular Set
if len(items) == 1 {
for key, value := range items {
return r.Set(ctx, key, value, ttl)
}
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
pipeline := conn.NewPipeline()
// Queue all SET commands
ttlSeconds := int(ttl.Seconds())
ttlMillis := ttl.Milliseconds()
for key, value := range items {
prefixedKey := r.prefixKey(key)
if ttl > 0 {
if ttlMillis < 1000 {
// Use PSETEX for sub-second TTLs
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
} else {
// Use SETEX for larger TTLs
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
}
} else {
pipeline.Queue("SET", prefixedKey, string(value))
}
}
// Execute pipeline
responses, err := pipeline.Execute()
if err != nil {
return fmt.Errorf("pipeline SetMany failed: %w", err)
}
// Check responses for errors (each should be "OK")
for i, resp := range responses {
if resp == nil {
continue
}
if str, ok := resp.(string); ok && str == "OK" {
continue
}
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
}
return nil
}
// GetMany retrieves multiple values from Redis
// GetMany retrieves multiple values from Redis using pipelining for efficiency
// This reduces N round-trips to a single round-trip
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if r.closed.Load() {
return nil, ErrBackendClosed
}
result := make(map[string][]byte)
if len(keys) == 0 {
return make(map[string][]byte), nil
}
// For simplicity, execute sequentially
for _, key := range keys {
value, _, exists, err := r.Get(ctx, key)
// For single key, use regular Get
if len(keys) == 1 {
result := make(map[string][]byte)
value, _, exists, err := r.Get(ctx, keys[0])
if err != nil {
return nil, err
}
if exists {
result[key] = value
result[keys[0]] = value
}
return result, nil
}
conn, err := r.pool.Get(ctx)
if err != nil {
return nil, err
}
defer r.pool.Put(conn)
pipeline := conn.NewPipeline()
// Queue all GET commands
prefixedKeys := make([]string, len(keys))
for i, key := range keys {
prefixedKeys[i] = r.prefixKey(key)
pipeline.Queue("GET", prefixedKeys[i])
}
// Execute pipeline
responses, err := pipeline.Execute()
if err != nil {
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
}
// Process responses
result := make(map[string][]byte)
for i, resp := range responses {
if resp == nil {
// Key doesn't exist
r.misses.Add(1)
continue
}
value, err := RESPString(resp)
if err != nil {
// Invalid response, skip this key
r.misses.Add(1)
continue
}
r.hits.Add(1)
result[keys[i]] = []byte(value)
}
return result, nil
+461
View File
@@ -0,0 +1,461 @@
package backends
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTestRedis creates a miniredis instance for testing
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err)
t.Cleanup(func() {
mr.Close()
})
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "test:",
PoolSize: 5,
})
require.NoError(t, err)
t.Cleanup(func() {
backend.Close()
})
return mr, backend
}
// TestPipeline_Basic tests basic pipeline functionality
func TestPipeline_Basic(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
config := &PoolConfig{
Address: mr.Addr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
defer pool.Put(conn)
t.Run("SingleCommand", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("SET", "single-key", "single-value")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 1)
assert.Equal(t, "OK", responses[0])
})
t.Run("MultipleCommands", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("SET", "key1", "value1")
pipeline.Queue("SET", "key2", "value2")
pipeline.Queue("SET", "key3", "value3")
pipeline.Queue("GET", "key1")
pipeline.Queue("GET", "key2")
pipeline.Queue("GET", "key3")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 6)
// First 3 are SET responses
assert.Equal(t, "OK", responses[0])
assert.Equal(t, "OK", responses[1])
assert.Equal(t, "OK", responses[2])
// Last 3 are GET responses
assert.Equal(t, "value1", responses[3])
assert.Equal(t, "value2", responses[4])
assert.Equal(t, "value3", responses[5])
})
t.Run("EmptyPipeline", func(t *testing.T) {
pipeline := conn.NewPipeline()
responses, err := pipeline.Execute()
require.NoError(t, err)
assert.Nil(t, responses)
})
t.Run("NilResponses", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("GET", "nonexistent-key")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 1)
assert.Nil(t, responses[0])
})
}
// TestPipeline_SetMany tests pipelined SetMany
func TestPipeline_SetMany(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
t.Run("SetManyItems", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
value, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key %s should exist", key)
assert.Equal(t, expectedValue, value)
}
})
t.Run("SetManyEmpty", func(t *testing.T) {
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
require.NoError(t, err)
})
t.Run("SetManySingleItem", func(t *testing.T) {
items := map[string][]byte{
"single-setmany": []byte("single-value"),
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
value, _, exists, err := backend.Get(ctx, "single-setmany")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("single-value"), value)
})
t.Run("SetManyNoTTL", func(t *testing.T) {
items := map[string][]byte{
"nottl-key1": []byte("value1"),
"nottl-key2": []byte("value2"),
}
err := backend.SetMany(ctx, items, 0)
require.NoError(t, err)
// Keys should exist
for key := range items {
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
})
}
// TestPipeline_GetMany tests pipelined GetMany
func TestPipeline_GetMany(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 10; i++ {
key := fmt.Sprintf("getmany-key-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
t.Run("GetManyExisting", func(t *testing.T) {
keys := make([]string, 10)
for i := 0; i < 10; i++ {
keys[i] = fmt.Sprintf("getmany-key-%d", i)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 10)
for i, key := range keys {
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
}
})
t.Run("GetManyMixed", func(t *testing.T) {
keys := []string{
"getmany-key-0", // exists
"nonexistent-key-1", // doesn't exist
"getmany-key-2", // exists
"nonexistent-key-2", // doesn't exist
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
assert.NotContains(t, results, "nonexistent-key-1")
assert.NotContains(t, results, "nonexistent-key-2")
})
t.Run("GetManyEmpty", func(t *testing.T) {
results, err := backend.GetMany(ctx, []string{})
require.NoError(t, err)
assert.NotNil(t, results)
assert.Len(t, results, 0)
})
t.Run("GetManySingleKey", func(t *testing.T) {
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
})
t.Run("GetManyAllNonexistent", func(t *testing.T) {
keys := []string{
"nonexistent-1",
"nonexistent-2",
"nonexistent-3",
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 0)
})
}
// TestPipeline_LargeBatch tests pipelining with large batches
func TestPipeline_LargeBatch(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
t.Run("SetMany100Items", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 100; i++ {
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Verify random samples
for _, i := range []int{0, 25, 50, 75, 99} {
key := fmt.Sprintf("large-batch-%d", i)
value, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
}
})
t.Run("GetMany100Items", func(t *testing.T) {
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("large-batch-%d", i)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 100)
})
}
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
func TestPipeline_Stats(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
// Set some items
items := map[string][]byte{
"stats-key-1": []byte("value1"),
"stats-key-2": []byte("value2"),
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Get items (some exist, some don't)
keys := []string{
"stats-key-1",
"stats-key-2",
"stats-key-nonexistent",
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2)
// Check stats
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Equal(t, int64(2), hits, "Should have 2 hits")
assert.Equal(t, int64(1), misses, "Should have 1 miss")
}
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
func BenchmarkPipeline_SetMany(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Prepare items
items := make(map[string][]byte)
for i := 0; i < 100; i++ {
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = backend.SetMany(ctx, items, time.Minute)
}
}
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
func BenchmarkPipeline_GetMany(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 100; i++ {
key := fmt.Sprintf("bench-key-%d", i)
value := []byte(fmt.Sprintf("bench-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
}
// Prepare keys
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("bench-key-%d", i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = backend.GetMany(ctx, keys)
}
}
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
func BenchmarkPipeline_VsSequential(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Prepare items
items := make(map[string][]byte)
keys := make([]string, 50)
for i := 0; i < 50; i++ {
key := fmt.Sprintf("compare-key-%d", i)
keys[i] = key
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
}
b.Run("Pipelined-Set", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = backend.SetMany(ctx, items, time.Minute)
}
})
b.Run("Sequential-Set", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, value := range items {
_ = backend.Set(ctx, key, value, time.Minute)
}
}
})
// Pre-populate for get benchmarks
_ = backend.SetMany(ctx, items, time.Hour)
b.Run("Pipelined-Get", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = backend.GetMany(ctx, keys)
}
})
b.Run("Sequential-Get", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, key := range keys {
_, _, _, _ = backend.Get(ctx, key)
}
}
})
}
+142 -3
View File
@@ -2,6 +2,7 @@ package backends
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@@ -31,6 +32,7 @@ type ConnectionPool struct {
type PoolConfig struct {
Address string
Password string
TLSServerName string // SNI server name; defaults to host(Address) when empty
DB int
MaxConnections int
ConnectTimeout time.Duration
@@ -39,6 +41,8 @@ type PoolConfig struct {
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
}
// NewConnectionPool creates a new connection pool
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
conn, err = p.createConnection(ctx)
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
var conn net.Conn
var err error
if p.config.EnableTLS {
serverName := p.config.TLSServerName
if serverName == "" {
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
serverName = host
}
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
MinVersion: tls.VersionTLS12,
}
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
} else {
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
@@ -336,3 +358,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
_, err := conn.Do("PING")
return err == nil
}
// Pipeline represents a Redis pipeline for batch operations
// It queues multiple commands and executes them in a single round-trip
type Pipeline struct {
conn *RedisConn
commands []pipelineCommand
mu sync.Mutex
}
// pipelineCommand represents a single command in the pipeline
type pipelineCommand struct {
command string
args []string
}
// NewPipeline creates a new pipeline for the connection
func (c *RedisConn) NewPipeline() *Pipeline {
return &Pipeline{
conn: c,
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
}
}
// Queue adds a command to the pipeline
func (p *Pipeline) Queue(command string, args ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.commands = append(p.commands, pipelineCommand{
command: command,
args: args,
})
}
// Execute sends all queued commands and returns all responses
// Returns a slice of responses in the same order as commands were queued
func (p *Pipeline) Execute() ([]interface{}, error) {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.commands) == 0 {
return nil, nil
}
if p.conn.closed.Load() {
return nil, ErrBackendClosed
}
p.conn.mu.Lock()
defer p.conn.mu.Unlock()
// Set write timeout for all commands
if p.conn.writeTimeout > 0 {
// Use longer timeout for batch operations
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
if timeout > 30*time.Second {
timeout = 30 * time.Second // Cap at 30 seconds
}
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
}
// Write all commands (pipelining - send all before reading any responses)
writer := NewRESPWriter(p.conn.conn)
for _, cmd := range p.commands {
cmdArgs := append([]string{cmd.command}, cmd.args...)
if err := writer.WriteCommand(cmdArgs...); err != nil {
writer.Release()
p.conn.closed.Store(true)
return nil, fmt.Errorf("pipeline write error: %w", err)
}
}
writer.Release()
// Set read timeout for all responses
if p.conn.readTimeout > 0 {
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
if timeout > 30*time.Second {
timeout = 30 * time.Second
}
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
}
// Read all responses
responses := make([]interface{}, len(p.commands))
reader := NewRESPReader(p.conn.conn)
defer reader.Release()
for i := range p.commands {
resp, err := reader.ReadResponse()
if err != nil {
// For nil responses, store nil instead of erroring
if errors.Is(err, ErrNilResponse) {
responses[i] = nil
continue
}
p.conn.closed.Store(true)
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
}
responses[i] = resp
}
return responses, nil
}
// Clear resets the pipeline for reuse
func (p *Pipeline) Clear() {
p.mu.Lock()
defer p.mu.Unlock()
p.commands = p.commands[:0]
}
// Len returns the number of queued commands
func (p *Pipeline) Len() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.commands)
}
+31 -1
View File
@@ -3,6 +3,7 @@ package backends
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
@@ -201,7 +202,7 @@ func TestConnectionPool_ContextCancellation(t *testing.T) {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
// Try to get another with canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
@@ -617,4 +618,33 @@ func TestRedisConn_TooManyArguments(t *testing.T) {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
}
// TestRedisConn_RejectOversizedArgumentBytes is a regression test for CodeQL
// alert #10 (go/allocation-size-overflow). A single argument larger than
// maxTotalArgBytes (64 MiB) must be rejected by the per-argument overflow
// guard in Do() before any allocation is attempted.
func TestRedisConn_RejectOversizedArgumentBytes(t *testing.T) {
mr := NewMiniredisServer(t)
pool, err := NewConnectionPool(&PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
largeArg := strings.Repeat("x", (64<<20)+1)
_, err = conn.Do("SET", "k", largeArg)
require.Error(t, err)
assert.Contains(t, err.Error(), "arguments too large")
}
+230
View File
@@ -0,0 +1,230 @@
package backends
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// drainRESPRequest consumes a single RESP request (array or inline) from r and
// returns true on success. Any read error returns false.
func drainRESPRequest(r *bufio.Reader) bool {
header, err := r.ReadString('\n')
if err != nil {
return false
}
if !strings.HasPrefix(header, "*") {
return true // inline command (single line) — already consumed
}
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
if err != nil || n <= 0 {
return false
}
for i := 0; i < n; i++ {
// Each bulk: "$len\r\n<bytes>\r\n"
if _, err := r.ReadString('\n'); err != nil {
return false
}
if _, err := r.ReadString('\n'); err != nil {
return false
}
}
return true
}
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
// answer PING with +PONG. Returns the listener address and a self-signed cert.
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
require.NoError(t, err)
tlsCert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{tlsCert},
MinVersion: tls.VersionTLS12,
})
require.NoError(t, err)
var wg sync.WaitGroup
stopCh := make(chan struct{})
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
c, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
defer conn.Close()
reader := bufio.NewReader(conn)
for {
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if !drainRESPRequest(reader) {
return
}
_, _ = conn.Write([]byte("+PONG\r\n"))
}
}(c)
}
}()
stop = func() {
close(stopCh)
_ = listener.Close()
wg.Wait()
}
return listener.Addr().String(), der, stop
}
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
// Regression test for issue #133 (enableTLS not propagated to client).
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
// TLSSkipVerify=false rejects a self-signed server cert.
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: false,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "tls")
}
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
// fails to handshake against a plain (non-TLS) listener.
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
plain, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer plain.Close()
go func() {
for {
c, acceptErr := plain.Accept()
if acceptErr != nil {
return
}
_ = c.Close()
}
}()
pool, err := NewConnectionPool(&PoolConfig{
Address: plain.Addr().String(),
MaxConnections: 1,
ConnectTimeout: 1 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
}
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
// when EnableTLS=false (default).
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
mr := NewMiniredisServer(t)
pool, err := NewConnectionPool(&PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: false,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
+15 -34
View File
@@ -7,52 +7,34 @@ import (
"io"
"strconv"
"strings"
"sync"
)
// RESP (REdis Serialization Protocol) implementation
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
//
// NOTE: sync.Pool was intentionally removed for Yaegi compatibility.
// Yaegi (Traefik's Go interpreter) has issues with sync.Pool and reflection
// that cause "reflect: call of reflect.Value.Field on zero Value" panics.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/120
var (
ErrInvalidRESP = errors.New("invalid RESP response")
ErrNilResponse = errors.New("nil response")
)
// Object pools for memory optimization - reduces allocations by 50-70%
var (
readerPool = sync.Pool{
New: func() interface{} {
return &RESPReader{
r: bufio.NewReaderSize(nil, 4096),
}
},
}
writerPool = sync.Pool{
New: func() interface{} {
return &RESPWriter{
w: nil,
}
},
}
)
// RESPWriter writes RESP protocol messages
type RESPWriter struct {
w io.Writer
}
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
// NewRESPWriter creates a new RESP writer
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
return &RESPWriter{w: w}
}
// Release returns the writer to the pool for reuse
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
func (w *RESPWriter) Release() {
w.w = nil
writerPool.Put(w)
// No-op: pooling removed for Yaegi compatibility
}
// WriteCommand writes a Redis command in RESP array format
@@ -78,17 +60,16 @@ type RESPReader struct {
r *bufio.Reader
}
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
// NewRESPReader creates a new RESP reader
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
return &RESPReader{
r: bufio.NewReaderSize(r, 4096),
}
}
// Release returns the reader to the pool for reuse
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
func (r *RESPReader) Release() {
r.r.Reset(nil)
readerPool.Put(r)
// No-op: pooling removed for Yaegi compatibility
}
// ReadResponse reads a RESP response and returns the parsed value
+183
View File
@@ -0,0 +1,183 @@
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// SingleflightCache wraps a CacheBackend with singleflight deduplication
// to prevent thundering herd problems when multiple concurrent requests
// try to fetch the same uncached key.
type SingleflightCache struct {
backend CacheBackend
mu sync.Mutex
calls map[string]*singleflightCall
// Metrics
deduplicatedCalls atomic.Int64
totalCalls atomic.Int64
}
// singleflightCall represents an in-flight or completed fetch call
type singleflightCall struct {
wg sync.WaitGroup
val []byte
ttl time.Duration
err error
done bool
}
// NewSingleflightCache creates a new singleflight-wrapped cache backend
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
return &SingleflightCache{
backend: backend,
calls: make(map[string]*singleflightCall),
}
}
// Fetcher is a function type that fetches data when cache misses
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
// per key when there's a cache miss. Concurrent calls for the same key will
// wait for the first call to complete and share its result.
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
s.totalCalls.Add(1)
// Try cache first
value, _, exists, err := s.backend.Get(ctx, key)
if err != nil {
return nil, err
}
if exists {
return value, nil
}
// Cache miss - use singleflight
s.mu.Lock()
// Check if there's already an in-flight call for this key
if call, ok := s.calls[key]; ok {
s.mu.Unlock()
s.deduplicatedCalls.Add(1)
// Wait for the in-flight call to complete
call.wg.Wait()
// Check context cancellation
if ctx.Err() != nil {
return nil, ctx.Err()
}
return call.val, call.err
}
// Create new call
call := &singleflightCall{}
call.wg.Add(1)
s.calls[key] = call
s.mu.Unlock()
// Execute the fetcher
call.val, call.ttl, call.err = fetcher(ctx)
call.done = true
// If successful, store in cache
if call.err == nil && call.val != nil {
// Use a background context for cache storage to ensure it completes
// even if the original context is canceled
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
cancel()
}
// Signal waiting goroutines
call.wg.Done()
// Clean up the call from the map after a short delay
// This allows late arrivals to still benefit from the result
go func() {
time.Sleep(100 * time.Millisecond)
s.mu.Lock()
if c, ok := s.calls[key]; ok && c == call {
delete(s.calls, key)
}
s.mu.Unlock()
}()
return call.val, call.err
}
// Get retrieves a value from the underlying cache backend
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
return s.backend.Get(ctx, key)
}
// Set stores a value in the underlying cache backend
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.backend.Set(ctx, key, value, ttl)
}
// Delete removes a key from the underlying cache backend
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
return s.backend.Delete(ctx, key)
}
// Exists checks if a key exists in the underlying cache backend
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
return s.backend.Exists(ctx, key)
}
// Clear removes all keys from the underlying cache backend
func (s *SingleflightCache) Clear(ctx context.Context) error {
return s.backend.Clear(ctx)
}
// GetStats returns cache statistics including singleflight metrics
func (s *SingleflightCache) GetStats() map[string]interface{} {
stats := s.backend.GetStats()
// Add singleflight-specific stats
totalCalls := s.totalCalls.Load()
deduped := s.deduplicatedCalls.Load()
stats["singleflight_total_calls"] = totalCalls
stats["singleflight_deduplicated"] = deduped
if totalCalls > 0 {
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
} else {
stats["singleflight_dedup_rate"] = float64(0)
}
s.mu.Lock()
stats["singleflight_inflight"] = len(s.calls)
s.mu.Unlock()
return stats
}
// Close shuts down the cache backend
func (s *SingleflightCache) Close() error {
return s.backend.Close()
}
// Ping checks if the backend is healthy
func (s *SingleflightCache) Ping(ctx context.Context) error {
return s.backend.Ping(ctx)
}
// GetBackend returns the underlying cache backend
func (s *SingleflightCache) GetBackend() CacheBackend {
return s.backend
}
// ResetStats resets the singleflight statistics
func (s *SingleflightCache) ResetStats() {
s.totalCalls.Store(0)
s.deduplicatedCalls.Store(0)
}
// Ensure SingleflightCache implements CacheBackend
var _ CacheBackend = (*SingleflightCache)(nil)
+510
View File
@@ -0,0 +1,510 @@
package backends
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
t.Run("CacheHit", func(t *testing.T) {
key := "existing-key"
value := []byte("existing-value")
// Pre-populate cache
err := cache.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
var fetchCalled bool
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCalled = true
return []byte("fetched-value"), time.Minute, nil
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
require.NoError(t, err)
assert.Equal(t, value, result)
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
})
t.Run("CacheMiss", func(t *testing.T) {
key := "missing-key"
expectedValue := []byte("fetched-value")
var fetchCalled bool
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCalled = true
return expectedValue, time.Minute, nil
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
require.NoError(t, err)
assert.Equal(t, expectedValue, result)
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
// Verify value was stored in cache
cached, _, exists, err := cache.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, cached)
})
t.Run("FetcherError", func(t *testing.T) {
key := "error-key"
expectedErr := errors.New("fetch failed")
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return nil, 0, expectedErr
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
assert.Error(t, err)
assert.Equal(t, expectedErr, err)
assert.Nil(t, result)
// Verify nothing was stored in cache
_, _, exists, err := cache.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
}
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
func TestSingleflightCache_Deduplication(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
key := "dedup-key"
expectedValue := []byte("dedup-value")
var fetchCount atomic.Int32
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
// Simulate slow fetch
time.Sleep(100 * time.Millisecond)
return expectedValue, time.Minute, nil
}
// Launch multiple concurrent requests
concurrency := 10
var wg sync.WaitGroup
results := make([][]byte, concurrency)
errs := make([]error, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
}(i)
}
wg.Wait()
// Verify all requests got the same result
for i := 0; i < concurrency; i++ {
assert.NoError(t, errs[i])
assert.Equal(t, expectedValue, results[i])
}
// Verify fetcher was only called once
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
// Verify deduplication stats
stats := cache.GetStats()
deduped := stats["singleflight_deduplicated"].(int64)
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
}
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
func TestSingleflightCache_DifferentKeys(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
var fetchCount atomic.Int32
fetchStarted := make(chan struct{}, 3)
fetchComplete := make(chan struct{})
fetcher := func(key string) Fetcher {
return func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
fetchStarted <- struct{}{}
<-fetchComplete // Wait for signal
return []byte("value-" + key), time.Minute, nil
}
}
// Launch concurrent requests for different keys
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
key := fmt.Sprintf("key-%d", idx)
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
}(i)
}
// Wait for all fetches to start
for i := 0; i < 3; i++ {
<-fetchStarted
}
// All 3 fetches should be running in parallel
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
// Release all fetches
close(fetchComplete)
wg.Wait()
}
// TestSingleflightCache_ContextCancellation tests context cancellation
func TestSingleflightCache_ContextCancellation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
key := "cancel-key"
fetchStarted := make(chan struct{})
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
close(fetchStarted)
// Simulate slow fetch
time.Sleep(500 * time.Millisecond)
return []byte("value"), time.Minute, nil
}
// Start first request with long timeout
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()
_, _ = cache.GetOrFetch(ctx, key, fetcher)
}()
// Wait for fetch to start
<-fetchStarted
// Start second request with short timeout
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = cache.GetOrFetch(ctx, key, fetcher)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
wg.Wait()
}
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
key := "error-prop-key"
expectedErr := errors.New("intentional error")
var fetchCount atomic.Int32
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
time.Sleep(50 * time.Millisecond)
return nil, 0, expectedErr
}
// Launch multiple concurrent requests
concurrency := 5
var wg sync.WaitGroup
errs := make([]error, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
}(i)
}
wg.Wait()
// Verify all requests got the same error
for i := 0; i < concurrency; i++ {
assert.Error(t, errs[i])
assert.Equal(t, expectedErr, errs[i])
}
// Verify fetcher was only called once
assert.Equal(t, int32(1), fetchCount.Load())
}
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
t.Run("Set", func(t *testing.T) {
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
require.NoError(t, err)
val, _, exists, err := cache.Get(ctx, "set-key")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("set-value"), val)
})
t.Run("Get", func(t *testing.T) {
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
require.NoError(t, err)
val, ttl, exists, err := cache.Get(ctx, "get-key")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("get-value"), val)
assert.Greater(t, ttl, time.Duration(0))
})
t.Run("Delete", func(t *testing.T) {
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
require.NoError(t, err)
deleted, err := cache.Delete(ctx, "delete-key")
require.NoError(t, err)
assert.True(t, deleted)
exists, err := cache.Exists(ctx, "delete-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
exists, err := cache.Exists(ctx, "nonexistent")
require.NoError(t, err)
assert.False(t, exists)
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
require.NoError(t, err)
exists, err = cache.Exists(ctx, "exists-key")
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
require.NoError(t, err)
err = cache.Clear(ctx)
require.NoError(t, err)
exists, err := cache.Exists(ctx, "clear-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Ping", func(t *testing.T) {
err := cache.Ping(ctx)
require.NoError(t, err)
})
}
// TestSingleflightCache_Stats tests statistics tracking
func TestSingleflightCache_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
// Make some calls
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(50 * time.Millisecond)
return []byte("value"), time.Minute, nil
}
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
}()
}
wg.Wait()
stats := cache.GetStats()
// Check singleflight stats exist
assert.Contains(t, stats, "singleflight_total_calls")
assert.Contains(t, stats, "singleflight_deduplicated")
assert.Contains(t, stats, "singleflight_dedup_rate")
assert.Contains(t, stats, "singleflight_inflight")
// Verify values
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
// Also check underlying backend stats are included
assert.Contains(t, stats, "hits")
assert.Contains(t, stats, "misses")
}
// TestSingleflightCache_ResetStats tests stats reset
func TestSingleflightCache_ResetStats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return []byte("value"), time.Minute, nil
}
// Make some calls
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
stats := cache.GetStats()
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
// Reset stats
cache.ResetStats()
stats = cache.GetStats()
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
}
// TestSingleflightCache_GetBackend tests GetBackend method
func TestSingleflightCache_GetBackend(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
assert.Equal(t, backend, cache.GetBackend())
}
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := fmt.Sprintf("key-%d", i%100)
_, _ = cache.GetOrFetch(ctx, key, fetcher)
}
}
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(time.Millisecond) // Simulate slow fetch
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
_, _ = cache.GetOrFetch(ctx, key, fetcher)
i++
}
})
}
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// All goroutines hit the same key
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
}
})
}
+1 -1
View File
@@ -232,7 +232,7 @@ func (m *Manager) Close() error {
var firstErr error
if err := m.tokenCache.Close(); err != nil && firstErr == nil {
if err := m.tokenCache.Close(); err != nil {
firstErr = err
}
if err := m.metadataCache.Close(); err != nil && firstErr == nil {
+1 -1
View File
@@ -397,7 +397,7 @@ func (wp *WorkerPool) Submit(task func()) error {
}
// worker is the main worker routine
func (wp *WorkerPool) worker(id int) {
func (wp *WorkerPool) worker(_ int) {
defer wp.workerWg.Done()
for {
+155
View File
@@ -0,0 +1,155 @@
package dcrstorage
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// FileStore implements Store using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
// For distributed environments, consider using RedisStore instead.
type FileStore struct {
basePath string
logger Logger
mu sync.RWMutex
}
// NewFileStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileStore(basePath string, logger Logger) *FileStore {
if basePath == "" {
basePath = "/tmp/oidc-client-credentials.json"
}
if logger == nil {
logger = NoOpLogger()
}
return &FileStore{
basePath: basePath,
logger: logger,
}
}
// BasePath returns the base path used for storing credentials
func (s *FileStore) BasePath() string {
return s.basePath
}
// GetFilePath returns the file path for storing credentials for a specific provider.
// For multi-tenant scenarios, each provider gets a separate file based on URL hash.
func (s *FileStore) GetFilePath(providerURL string) string {
if providerURL == "" {
return s.basePath
}
// Hash provider URL for filename safety and uniqueness
hash := sha256.Sum256([]byte(providerURL))
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes for shorter filename
ext := filepath.Ext(s.basePath)
base := strings.TrimSuffix(s.basePath, ext)
if ext == "" {
ext = ".json"
}
return fmt.Sprintf("%s-%s%s", base, hashStr, ext)
}
// Save stores the client registration response to a file
func (s *FileStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
if creds == nil {
return fmt.Errorf("credentials cannot be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
filePath := s.GetFilePath(providerURL)
// Ensure parent directory exists
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create credentials directory: %w", err)
}
data, err := json.MarshalIndent(creds, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Write with restrictive permissions (owner read/write only)
if err := os.WriteFile(filePath, data, 0600); err != nil {
return fmt.Errorf("failed to write credentials file: %w", err)
}
s.logger.Debugf("Saved client credentials to %s", filePath)
return nil
}
// Load retrieves stored credentials from a file.
// Returns nil, nil if no credentials file exists (not an error).
func (s *FileStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filePath := s.GetFilePath(providerURL)
// #nosec G304 -- path is constructed from trusted config values via GetFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No credentials file exists - not an error
}
return nil, fmt.Errorf("failed to read credentials file: %w", err)
}
var creds ClientRegistrationResponse
if err := json.Unmarshal(data, &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
}
s.logger.Debugf("Loaded client credentials from %s", filePath)
return &creds, nil
}
// Delete removes the credentials file for a provider
func (s *FileStore) Delete(ctx context.Context, providerURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
filePath := s.GetFilePath(providerURL)
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
return nil // File doesn't exist, nothing to delete
}
return fmt.Errorf("failed to remove credentials file: %w", err)
}
s.logger.Debugf("Deleted client credentials from %s", filePath)
return nil
}
// Exists checks if credentials exist for a provider
func (s *FileStore) Exists(ctx context.Context, providerURL string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filePath := s.GetFilePath(providerURL)
_, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check credentials file: %w", err)
}
return true, nil
}
+161
View File
@@ -0,0 +1,161 @@
package dcrstorage
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
)
// Cache defines the interface for cache operations needed by RedisStore.
// This allows the main package to provide a cache implementation without
// creating circular dependencies.
type Cache interface {
// Get retrieves a value from the cache
Get(key string) (any, bool)
// Set stores a value in the cache with a TTL
Set(key string, value any, ttl time.Duration) error
// Delete removes a value from the cache
Delete(key string)
}
// RedisStore implements Store using a Cache-backed storage.
// This storage backend enables sharing DCR credentials across multiple Traefik instances
// in distributed environments (e.g., Kubernetes with multiple ingress pods).
type RedisStore struct {
cache Cache
keyPrefix string
logger Logger
mu sync.RWMutex
}
// NewRedisStore creates a new cache-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisStore(cache Cache, keyPrefix string, logger Logger) *RedisStore {
if keyPrefix == "" {
keyPrefix = "dcr:creds:"
}
if logger == nil {
logger = NoOpLogger()
}
return &RedisStore{
cache: cache,
keyPrefix: keyPrefix,
logger: logger,
}
}
// makeKey creates a unique cache key for a provider URL.
// Uses SHA256 hash of the provider URL for consistent key generation across nodes.
func (s *RedisStore) makeKey(providerURL string) string {
if providerURL == "" {
return s.keyPrefix + "default"
}
hash := sha256.Sum256([]byte(providerURL))
return s.keyPrefix + hex.EncodeToString(hash[:])
}
// Save stores the client registration response in the cache.
// TTL is calculated based on client_secret_expires_at if available.
func (s *RedisStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
if creds == nil {
return fmt.Errorf("credentials cannot be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
key := s.makeKey(providerURL)
// Calculate TTL based on client_secret_expires_at if available
ttl := 30 * 24 * time.Hour // Default: 30 days
if creds.ClientSecretExpiresAt > 0 {
expiresAt := time.Unix(creds.ClientSecretExpiresAt, 0)
ttl = time.Until(expiresAt)
if ttl < 0 {
return fmt.Errorf("credentials already expired")
}
// Add a small buffer to ensure we don't serve expired credentials
if ttl > time.Minute {
ttl -= time.Minute
}
}
// Serialize credentials to JSON for storage
data, err := json.Marshal(creds)
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Store as string in cache (will be serialized by the cache backend)
if err := s.cache.Set(key, string(data), ttl); err != nil {
return fmt.Errorf("failed to store credentials in cache: %w", err)
}
s.logger.Debugf("Saved client credentials to cache with key %s (TTL: %v)", key, ttl)
return nil
}
// Load retrieves stored credentials from the cache.
// Returns nil, nil if no credentials exist (not an error).
func (s *RedisStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
key := s.makeKey(providerURL)
value, exists := s.cache.Get(key)
if !exists {
return nil, nil // No credentials stored - not an error
}
// Handle different value types from cache
var jsonData string
switch v := value.(type) {
case string:
jsonData = v
case []byte:
jsonData = string(v)
default:
// Try to see if it's already the struct (from local cache)
if creds, ok := value.(*ClientRegistrationResponse); ok {
return creds, nil
}
return nil, fmt.Errorf("unexpected credentials type in cache: %T", value)
}
var creds ClientRegistrationResponse
if err := json.Unmarshal([]byte(jsonData), &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials from cache: %w", err)
}
s.logger.Debugf("Loaded client credentials from cache with key %s", key)
return &creds, nil
}
// Delete removes stored credentials from the cache
func (s *RedisStore) Delete(ctx context.Context, providerURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
key := s.makeKey(providerURL)
s.cache.Delete(key)
s.logger.Debugf("Deleted client credentials from cache with key %s", key)
return nil
}
// Exists checks if credentials exist in the cache for a provider
func (s *RedisStore) Exists(ctx context.Context, providerURL string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
key := s.makeKey(providerURL)
_, exists := s.cache.Get(key)
return exists, nil
}
+90
View File
@@ -0,0 +1,90 @@
// Package dcrstorage provides storage backends for OIDC Dynamic Client Registration credentials.
// It supports both file-based and Redis-based storage for persisting client credentials
// across application restarts and distributed deployments.
package dcrstorage
import (
"context"
)
// StorageBackend represents the type of storage backend for DCR credentials
type StorageBackend string
const (
// StorageBackendFile uses file-based storage (default for backward compatibility)
StorageBackendFile StorageBackend = "file"
// StorageBackendRedis uses Redis for distributed storage
StorageBackendRedis StorageBackend = "redis"
// StorageBackendAuto automatically selects Redis if available, otherwise file
StorageBackendAuto StorageBackend = "auto"
)
// Logger interface for DCR storage operations
type Logger interface {
Debug(msg string)
Debugf(format string, args ...any)
Info(msg string)
Infof(format string, args ...any)
Error(msg string)
Errorf(format string, args ...any)
}
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
SubjectType string `json:"subject_type,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// Store defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type Store interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// noOpLogger is a no-op implementation of Logger for default use
type noOpLogger struct{}
func (n noOpLogger) Debug(msg string) {}
func (n noOpLogger) Debugf(format string, args ...any) {}
func (n noOpLogger) Info(msg string) {}
func (n noOpLogger) Infof(format string, args ...any) {}
func (n noOpLogger) Error(msg string) {}
func (n noOpLogger) Errorf(format string, args ...any) {}
// NoOpLogger returns a no-op logger instance
func NoOpLogger() Logger {
return noOpLogger{}
}
+464
View File
@@ -0,0 +1,464 @@
package dcrstorage
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// mockCache implements Cache for testing
type mockCache struct {
data map[string]cacheEntry
mu sync.RWMutex
}
type cacheEntry struct {
value any
expiresAt time.Time
}
func newMockCache() *mockCache {
return &mockCache{data: make(map[string]cacheEntry)}
}
func (m *mockCache) Get(key string) (any, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
entry, ok := m.data[key]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
return nil, false
}
return entry.value, true
}
func (m *mockCache) Set(key string, value any, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = cacheEntry{
value: value,
expiresAt: time.Now().Add(ttl),
}
return nil
}
func (m *mockCache) Delete(key string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
}
func TestFileStore_SaveLoad(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileStore(filepath.Join(tempDir2, "nonexistent.json"), nil)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
func TestFileStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
func TestFileStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
for range concurrency {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
for range concurrency {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
func TestFileStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
func TestFileStore_DefaultPath(t *testing.T) {
t.Parallel()
store := NewFileStore("", nil)
if store.BasePath() == "" {
t.Error("Expected default base path")
}
}
func TestRedisStore_WithMockCache(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
func TestRedisStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0,
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
func TestRedisStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
func TestFileStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
filePath := store.GetFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
func TestFileStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
store := NewFileStore(deepPath, nil)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+1 -1
View File
@@ -173,7 +173,7 @@ func (m *FeatureManager) LoadFromEnv() {
for name, flag := range flags {
envVar := "FEATURE_" + name
if value := os.Getenv(envVar); value != "" {
enabled := strings.ToLower(value) == "true" || value == "1"
enabled := strings.EqualFold(value, "true") || value == "1"
flag.enabled.Store(enabled)
}
}
+1 -1
View File
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
var filteredScopes []string
for _, scope := range scopes {
if strings.ToLower(scope) != ScopeOfflineAccess {
if !strings.EqualFold(scope, ScopeOfflineAccess) {
filteredScopes = append(filteredScopes, scope)
}
}
+2 -1
View File
@@ -147,7 +147,8 @@ func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
return p
}
case ProviderTypeKeycloak:
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/auth/realms/") {
// Match both Keycloak <17 (`/auth/realms/`) and 17+ (`/realms/`).
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/realms/") {
return p
}
case ProviderTypeAWSCognito:
+6 -1
View File
@@ -225,10 +225,15 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
expected: oktaProvider,
},
{
name: "Keycloak provider detection",
name: "Keycloak provider detection (legacy /auth/realms/)",
issuerURL: "https://auth.example.com/auth/realms/master",
expected: keycloakProvider,
},
{
name: "Keycloak provider detection (modern /realms/, KC 17+)",
issuerURL: "https://auth.example.com/realms/master",
expected: keycloakProvider,
},
{
name: "AWS Cognito provider detection",
issuerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example",
+11 -10
View File
@@ -18,16 +18,17 @@ func GetProviderWarnings(providerType ProviderType) []ProviderWarning {
switch providerType {
case ProviderTypeGitHub:
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
})
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
warnings = append(warnings,
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
},
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
case ProviderTypeAuth0:
warnings = append(warnings, ProviderWarning{
+4 -3
View File
@@ -116,7 +116,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
// Continue to next attempt
case <-ctx.Done():
re.RecordFailure()
return fmt.Errorf("retry cancelled: %w", ctx.Err())
return fmt.Errorf("retry canceled: %w", ctx.Err())
}
}
@@ -301,7 +301,7 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
}
}
allMetrics["summary"] = map[string]interface{}{
summary := map[string]interface{}{
"totalMechanisms": len(rm.mechanisms),
"totalRequests": totalRequests,
"totalSuccesses": totalSuccesses,
@@ -310,8 +310,9 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
if totalRequests > 0 {
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
summary["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
}
allMetrics["summary"] = summary
return allMetrics
}
+3 -3
View File
@@ -223,7 +223,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
wg.Wait()
if execErr == nil {
t.Error("Expected error when context is cancelled")
t.Error("Expected error when context is canceled")
}
}
@@ -240,7 +240,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing
})
if err == nil {
t.Error("Expected error when context is already cancelled")
t.Error("Expected error when context is already canceled")
}
}
@@ -282,7 +282,7 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
{name: "EOF", err: errors.New("EOF"), expected: false},
{name: "random error", err: errors.New("something else"), expected: false},
{name: "context cancelled", err: context.Canceled, expected: false},
{name: "context canceled", err: context.Canceled, expected: false},
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
}
+135
View File
@@ -0,0 +1,135 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
// the fix for issue #132: token refresh path hardcoded the "email" claim and
// ignored the configured userIdentifierClaim. Keycloak users without an email
// claim (using sub or another identifier) were being kicked out on refresh
// even though their initial login worked.
//
// The callback path (auth_flow.go) already honored userIdentifierClaim with
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
// after PR #100 (commit a316a98).
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
tests := []struct {
claims map[string]any
name string
userIdentifierClaim string
expectedIdentifier string
expectSuccess bool
}{
{
name: "sub claim configured, only sub present (Keycloak no-email case)",
userIdentifierClaim: "sub",
claims: map[string]any{
"sub": "user-uuid-keycloak-12345",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user-uuid-keycloak-12345",
},
{
name: "preferred_username configured, claim present",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "user-uuid-12345",
"preferred_username": "alice",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "alice",
},
{
name: "configured claim missing, falls back to sub",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "fallback-sub-id",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "fallback-sub-id",
},
{
name: "email default, email present (backward compatibility)",
userIdentifierClaim: "email",
claims: map[string]any{
"sub": "user-uuid-12345",
"email": "user@example.com",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user@example.com",
},
{
name: "email default, no email and no sub - refresh fails",
userIdentifierClaim: "email",
claims: map[string]any{
"exp": float64(9999999999),
},
expectSuccess: false,
expectedIdentifier: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionManager, err := NewSessionManager(
"test-encryption-key-32-bytes-long!!",
false,
"",
"",
0,
NewLogger("error"),
)
if err != nil {
t.Fatalf("session manager: %v", err)
}
defer sessionManager.Shutdown()
capturedClaims := tt.claims
tOidc := &TraefikOidc{
logger: NewLogger("error"),
userIdentifierClaim: tt.userIdentifierClaim,
sessionManager: sessionManager,
tokenExchanger: &EnhancedMockTokenExchanger{
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
IDToken: "new-id-token-jwt",
ExpiresIn: 3600,
},
},
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
extractClaimsFunc: func(token string) (map[string]any, error) {
return capturedClaims, nil
},
}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("get session: %v", err)
}
defer session.returnToPoolSafely()
session.SetRefreshToken("initial-refresh-token")
refreshed := tOidc.refreshToken(rw, req, session)
if refreshed != tt.expectSuccess {
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
}
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
}
})
}
}
+256
View File
@@ -0,0 +1,256 @@
package traefikoidc
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
// verifies the fix for issue #134.
//
// Symptom (before fix): with a Redis backend wired into UniversalCache,
// caching the parsed *parsedJWKS triggered:
//
// json: cannot unmarshal number 2251513...
// into Go value of type float64
//
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
// interface{} parses numbers as float64, which cannot represent that range.
// The user saw the error log on every request even though auth still worked
// (fallback path rebuilt the keys in memory).
//
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
// distributed backend never sees them.
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "issue134:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-test-kid"
jwk := JWK{
Kty: "RSA",
Use: "sig",
Alg: "RS256",
Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
}
var fetchCount int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&fetchCount, 1)
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
}))
defer server.Close()
errBuf := &bytes.Buffer{}
infoBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(infoBuf, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 100,
Logger: logger,
}, backend)
defer cache.Close()
jwkCache := &JWKCache{cache: cache}
ctx := context.Background()
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
require.NoError(t, err, "first GetPublicKey should succeed")
require.NotNil(t, pub1)
gotRSA, ok := pub1.(*rsa.PublicKey)
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
require.NoError(t, err, "second GetPublicKey should succeed")
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
errOutput := errBuf.String()
assert.NotContains(t, errOutput, "Failed to deserialize",
"deserialize error must not appear with the fix in place; got: %s", errOutput)
assert.NotContains(t, errOutput, "into Go value of type float64",
"float64 unmarshal error must not appear; got: %s", errOutput)
parsedKey := server.URL + parsedKeysSuffix
jwksKey := server.URL
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
fullKey := redisCfg.RedisPrefix + k
assert.False(t, mr.Exists(fullKey),
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
}
}
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
// data left in Redis under a JWK :parsed key from a prior buggy version is
// ignored: the local-only fix never reads that key, so no log spam, and the
// fallback path returns a real *rsa.PublicKey.
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "issue134stale:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-test-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
}))
defer server.Close()
// Pre-poison Redis with the kind of payload the old buggy path would have
// produced (huge unquoted JSON number for the modulus). With the fix the
// JWKCache must not even read this key.
poisoned := []byte("\x01" + strings.Replace(
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
"NUMBER", rsaKey.N.String(), 1,
))
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
errBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(io.Discard, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 100,
Logger: logger,
}, backend)
defer cache.Close()
jwkCache := &JWKCache{cache: cache}
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
require.NoError(t, err)
require.NotNil(t, pub)
gotRSA, ok := pub.(*rsa.PublicKey)
require.True(t, ok)
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
}
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
// pair never reads or writes the configured backend.
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "local:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 10,
Logger: GetSingletonNoOpLogger(),
}, backend)
defer cache.Close()
type unsafeShape struct {
hidden map[string]interface{}
}
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
got, found := cache.GetLocal("local-key")
require.True(t, found)
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
for _, k := range mr.Keys() {
assert.NotContains(t, k, "local-key",
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
}
cache.mu.Lock()
delete(cache.items, "local-key")
cache.lruList.Init()
cache.currentSize = 0
cache.currentMemory = 0
cache.mu.Unlock()
_, found = cache.GetLocal("local-key")
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
}
// big2bytes returns the big-endian byte slice for a positive int.
func big2bytes(e int) []byte {
if e <= 0 {
return []byte{}
}
var buf []byte
for e > 0 {
buf = append([]byte{byte(e & 0xff)}, buf...)
e >>= 8
}
return buf
}
// samePublicKey reports whether two crypto.PublicKey instances represent the
// same RSA key, used to confirm cache hits return identical reconstructed
// keys.
func samePublicKey(a, b interface{}) bool {
ar, ok1 := a.(*rsa.PublicKey)
br, ok2 := b.(*rsa.PublicKey)
if !ok1 || !ok2 {
return false
}
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
}
+90 -6
View File
@@ -2,6 +2,7 @@ package traefikoidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
@@ -18,6 +19,18 @@ import (
"time"
)
// parsedKeysSuffix marks the parallel UniversalCache entry that stores
// pre-parsed public keys for a given JWKS URL.
const parsedKeysSuffix = ":parsed"
// parsedJWKS holds keys decoded from a JWKSet, indexed by kid. Storing the
// already-parsed crypto.PublicKey avoids re-running the DER/PEM round trip
// on every JWT verification — a costly operation under the yaegi interpreter
// that hosts Traefik plugins.
type parsedJWKS struct {
keys map[string]crypto.PublicKey
}
// JWK represents a JSON Web Key as defined in RFC 7517.
// It can represent different key types including RSA, EC, and symmetric keys.
type JWK struct {
@@ -49,6 +62,7 @@ type JWKCache struct {
// JWKCacheInterface defines the contract for JWK caching implementations.
type JWKCacheInterface interface {
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error)
Cleanup()
Close()
}
@@ -62,9 +76,15 @@ func NewJWKCache() *JWKCache {
}
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
//
// The entry is stored locally only via SetLocal/GetLocal. Going through a
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
// into map[string]interface{}, the type assertion below fails, and every
// request refetches from the upstream. JWK rotation is rare and a per-replica
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Check cache first
if cachedValue, found := c.cache.Get(jwksURL); found {
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
return jwks, nil
}
@@ -74,7 +94,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
defer c.mutex.Unlock()
// Double-check after acquiring lock
if cachedValue, found := c.cache.Get(jwksURL); found {
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
return jwks, nil
}
@@ -91,11 +111,75 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
}
// Cache for 1 hour
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
return jwks, nil
}
// GetPublicKey returns the parsed public key for a given kid, fetching and
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
// stored alongside the raw JWKSet under a sibling cache key with the same
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
//
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
// which contain *big.Int that marshals to a hundreds-digit JSON number.
// On a distributed backend round-trip, json.Unmarshal into interface{} would
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
// Marshal, leaking the modulus into the cached payload (issue #134).
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
parsedKey := jwksURL + parsedKeysSuffix
if v, found := c.cache.GetLocal(parsedKey); found {
if pj, ok := v.(*parsedJWKS); ok {
if k, ok := pj.keys[kid]; ok {
return k, nil
}
}
}
jwks, err := c.GetJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
pj := buildParsedJWKS(jwks)
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
if k, ok := pj.keys[kid]; ok {
return k, nil
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
// buildParsedJWKS pre-parses every JWK in the set into the matching
// crypto.PublicKey, indexed by kid. Errors on individual keys are skipped so
// a single bad key does not block the rest of the keyset.
func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
out := make(map[string]crypto.PublicKey, len(jwks.Keys))
for i := range jwks.Keys {
k := &jwks.Keys[i]
if k.Kid == "" {
continue
}
var pub crypto.PublicKey
var err error
switch k.Kty {
case "RSA":
pub, err = k.ToRSAPublicKey()
case "EC":
pub, err = k.ToECDSAPublicKey()
default:
continue
}
if err != nil {
continue
}
out[k.Kid] = pub
}
return &parsedJWKS{keys: out}
}
// Cleanup is a no-op as cleanup is handled by UniversalCache
func (c *JWKCache) Cleanup() {
// Handled internally by UniversalCache
@@ -213,9 +297,9 @@ func (jwk *JWK) ToECDSAPublicKey() (*ecdsa.PublicKey, error) {
// GetKey finds a key by its ID (kid) in the JWKSet.
// Returns nil if no key with the given ID is found.
func (jwks *JWKSet) GetKey(kid string) *JWK {
for _, key := range jwks.Keys {
if key.Kid == kid {
return &key
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
return &jwks.Keys[i]
}
}
return nil
+16 -9
View File
@@ -120,7 +120,7 @@ func getReplayCacheStats() (size int, maxSize int) {
// Parameters:
// - ctx: Parent context for cancellation
// - logger: Logger for debug output (can be nil)
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
registry := GetGlobalTaskRegistry()
// Define the cleanup task function
@@ -528,6 +528,21 @@ func verifyNotBefore(notBefore float64) error {
// - An error if the key parsing fails, the algorithm is unsupported,
// or the signature verification fails
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
return verifySignatureWithKey(tokenString, pubKey, alg)
}
// verifySignatureWithKey verifies a JWT signature using an already-parsed
// public key, skipping the PEM-encode/decode round trip that verifySignature
// performs. This is the hot path used by VerifyJWTSignatureAndClaims.
func verifySignatureWithKey(tokenString string, pubKey crypto.PublicKey, alg string) error {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
@@ -537,14 +552,6 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
+502
View File
@@ -0,0 +1,502 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
const (
// logoutTokenType is the expected typ claim for logout tokens
// #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
logoutTokenType = "logout+jwt"
// sessionInvalidationTTL is how long to remember invalidated sessions
// Should be at least as long as your session max age
sessionInvalidationTTL = 25 * time.Hour
)
// LogoutTokenClaims represents the claims in an OIDC logout token
// as defined in OpenID Connect Back-Channel Logout 1.0
type LogoutTokenClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub,omitempty"`
Audience interface{} `json:"aud"` // Can be string or []string
IssuedAt int64 `json:"iat"`
JTI string `json:"jti"`
Events map[string]interface{} `json:"events"`
SessionID string `json:"sid,omitempty"`
Nonce string `json:"nonce,omitempty"` // Must NOT be present
}
// handleBackchannelLogout processes OIDC Backchannel Logout requests.
// It accepts POST requests with a logout_token parameter containing a JWT
// that identifies which session(s) to terminate.
//
// According to OpenID Connect Back-Channel Logout 1.0:
// - The logout_token is a JWT signed by the IdP
// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
// - The RP must validate the token and invalidate the matching session(s)
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing the logout_token
func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing backchannel logout request")
// Backchannel logout must be POST
if req.Method != http.MethodPost {
t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse form data to get logout_token
if err := req.ParseForm(); err != nil {
t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
http.Error(rw, "Bad request", http.StatusBadRequest)
return
}
logoutToken := req.FormValue("logout_token")
if logoutToken == "" {
// Also try reading from request body as raw JWT
body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
if err == nil && len(body) > 0 {
logoutToken = string(body)
}
}
if logoutToken == "" {
t.logger.Error("Backchannel logout: missing logout_token")
http.Error(rw, "logout_token required", http.StatusBadRequest)
return
}
// Parse and validate the logout token
claims, err := t.validateLogoutToken(logoutToken)
if err != nil {
t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
// Return 400 for invalid token per spec
http.Error(rw, "Invalid logout token", http.StatusBadRequest)
return
}
// Invalidate session(s) based on sid or sub
if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
claims.SessionID, claims.Subject)
// Return 200 OK with empty body per spec
rw.WriteHeader(http.StatusOK)
}
// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
// It accepts GET requests with 'iss' and 'sid' query parameters that identify
// which session to terminate. The IdP typically loads this URL in an iframe.
//
// According to OpenID Connect Front-Channel Logout 1.0:
// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
// - The RP should clear the session and return a response (typically empty or image)
// - The response must be cacheable to allow the IdP to load it in an iframe
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing iss and sid parameters
func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing front-channel logout request")
// Front-channel logout should be GET
if req.Method != http.MethodGet {
t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get iss and sid from query parameters
iss := req.URL.Query().Get("iss")
sid := req.URL.Query().Get("sid")
// Validate issuer matches our expected issuer
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
if iss != "" && iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return
}
// Must have at least sid for front-channel logout
if sid == "" {
t.logger.Error("Front-channel logout: missing sid parameter")
http.Error(rw, "sid parameter required", http.StatusBadRequest)
return
}
// Invalidate the session
if err := t.invalidateSession(sid, ""); err != nil {
t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
// Return a minimal HTML response that's suitable for iframe loading
// Set headers to allow embedding and caching
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Cache-Control", "no-cache, no-store")
rw.Header().Set("Pragma", "no-cache")
// Allow embedding in iframes from any origin (required for front-channel logout)
rw.Header().Del("X-Frame-Options")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("<!DOCTYPE html><html><head><title>Logged Out</title></head><body></body></html>"))
}
// validateLogoutToken parses and validates a logout token JWT.
// It verifies the token signature, issuer, audience, and required claims.
//
// Parameters:
// - tokenString: The raw JWT logout token
//
// Returns:
// - The parsed logout token claims
// - An error if validation fails
func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
// Parse the JWT
jwt, err := parseJWT(tokenString)
if err != nil {
return nil, fmt.Errorf("failed to parse logout token: %w", err)
}
// Check token type if present
if typ, ok := jwt.Header["typ"].(string); ok {
// The typ should be "logout+jwt" or omitted
if typ != "" && typ != logoutTokenType && typ != "JWT" {
return nil, fmt.Errorf("invalid token type: %s", typ)
}
}
// Verify signature only (not standard claims - logout tokens don't have 'exp')
if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
return nil, fmt.Errorf("signature verification failed: %w", err)
}
// Extract claims
claims := &LogoutTokenClaims{}
claimsJSON, err := json.Marshal(jwt.Claims)
if err != nil {
return nil, fmt.Errorf("failed to marshal claims: %w", err)
}
if err := json.Unmarshal(claimsJSON, claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
}
// Validate required claims
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
// Validate issuer
if claims.Issuer != expectedIssuer {
return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
}
// Validate audience (must contain our client_id)
if !t.validateLogoutTokenAudience(claims.Audience) {
return nil, fmt.Errorf("audience validation failed")
}
// Validate iat (issued at) - must be present and not too old
if claims.IssuedAt == 0 {
return nil, fmt.Errorf("missing iat claim")
}
iatTime := time.Unix(claims.IssuedAt, 0)
// Allow up to 5 minutes clock skew and 10 minutes token age
if time.Since(iatTime) > 15*time.Minute {
return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
}
// Token should not be from the future (with 5 min clock skew tolerance)
if iatTime.After(time.Now().Add(5 * time.Minute)) {
return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
}
// Validate events claim - must contain the logout event
if claims.Events == nil {
return nil, fmt.Errorf("missing events claim")
}
if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
return nil, fmt.Errorf("missing backchannel-logout event in events claim")
}
// Validate that nonce is NOT present (per spec)
if claims.Nonce != "" {
return nil, fmt.Errorf("nonce claim must not be present in logout token")
}
// Must have either sid or sub (or both)
if claims.SessionID == "" && claims.Subject == "" {
return nil, fmt.Errorf("logout token must contain either sid or sub claim")
}
return claims, nil
}
// validateLogoutTokenAudience checks if the logout token audience contains our client_id
func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
switch v := aud.(type) {
case string:
return v == t.clientID
case []interface{}:
for _, a := range v {
if s, ok := a.(string); ok && s == t.clientID {
return true
}
}
case []string:
for _, a := range v {
if a == t.clientID {
return true
}
}
}
return false
}
// verifyLogoutTokenSignature verifies only the signature of a logout token.
// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
//
// Parameters:
// - jwt: The parsed JWT structure
// - tokenString: The raw token string for signature verification
//
// Returns:
// - An error if signature verification fails
func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
t.logger.Debug("Verifying logout token signature")
// Read jwksURL with RLock
t.metadataMu.RLock()
jwksURL := t.jwksURL
t.metadataMu.RUnlock()
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
if jwks == nil {
return fmt.Errorf("JWKS is nil, cannot verify token")
}
kid, ok := jwt.Header["kid"].(string)
if !ok || kid == "" {
return fmt.Errorf("missing key ID in token header")
}
alg, ok := jwt.Header["alg"].(string)
if !ok || alg == "" {
return fmt.Errorf("missing algorithm in token header")
}
// Find the matching key in JWKS
var matchingKey *JWK
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
matchingKey = &jwks.Keys[i]
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Logout token signature verified successfully")
return nil
}
// invalidateSession marks a session as invalidated in the session invalidation cache.
// It stores entries by both sid and sub if available.
//
// Parameters:
// - sid: The session ID to invalidate (from the 'sid' claim)
// - sub: The subject to invalidate (from the 'sub' claim)
//
// Returns:
// - An error if the invalidation fails
func (t *TraefikOidc) invalidateSession(sid, sub string) error {
if t.sessionInvalidationCache == nil {
return fmt.Errorf("session invalidation cache not initialized")
}
now := time.Now().Unix()
// Store by session ID
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sid: %s", sid)
}
// Store by subject (invalidates all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sub: %s", sub)
}
return nil
}
// isSessionInvalidated checks if a session has been invalidated via backchannel
// or front-channel logout.
//
// Parameters:
// - sid: The session ID to check
// - sub: The subject to check
// - sessionCreatedAt: When the session was created (to compare against invalidation time)
//
// Returns:
// - true if the session has been invalidated, false otherwise
func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
if t.sessionInvalidationCache == nil {
return false
}
// Truncate session creation time to seconds for fair comparison with Unix timestamps
sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
// Check by session ID first (more specific)
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Session was invalidated at or after it was created
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sid: %s", sid)
return true
}
}
}
}
// Check by subject (all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Sessions for this subject created at or before invalidation are invalid
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sub: %s", sub)
return true
}
}
}
}
return false
}
// buildSessionInvalidationKey creates a cache key for session invalidation
func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
}
// extractSessionInfo extracts sid and sub from an ID token for session tracking
func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
if idToken == "" {
return "", "", time.Time{}
}
jwt, err := parseJWT(idToken)
if err != nil {
return "", "", time.Time{}
}
// Extract sid (session ID)
if sidVal, ok := jwt.Claims["sid"].(string); ok {
sid = sidVal
}
// Extract sub (subject)
if subVal, ok := jwt.Claims["sub"].(string); ok {
sub = subVal
}
// Extract iat for session creation time
if iatVal, ok := jwt.Claims["iat"].(float64); ok {
createdAt = time.Unix(int64(iatVal), 0)
} else {
// Default to now if iat not present
createdAt = time.Now()
}
return sid, sub, createdAt
}
// determineLogoutPath checks if the given path matches any logout URL
func (t *TraefikOidc) determineLogoutPath(path string) string {
// Check backchannel logout path
if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
return "backchannel"
}
// Check front-channel logout path
if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
return "frontchannel"
}
// Check regular logout path (for RP-initiated logout)
if path == t.logoutURLPath {
return "rp"
}
return ""
}
// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
func normalizeLogoutPath(path string) string {
if path == "" {
return ""
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// Prevent open redirect: ensure second character is not / or \
// This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
// Strip leading slashes/backslashes and re-normalize
path = strings.TrimLeft(path, "/\\")
if path != "" {
path = "/" + path
}
}
return path
}
+1660
View File
File diff suppressed because it is too large Load Diff
+69 -18
View File
@@ -113,12 +113,26 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
}
// Setup HTTP client
caPool, err := config.loadCACertPool()
if err != nil {
return nil, fmt.Errorf("failed to load CA certificates: %w", err)
}
if config.InsecureSkipVerify {
logger.Errorf("SECURITY WARNING: InsecureSkipVerify is enabled for the OIDC provider. TLS certificate verification is DISABLED. Do not use in production.")
}
var httpClient *http.Client
if config.HTTPClient != nil {
httpClient = config.HTTPClient
} else {
httpClient = CreateDefaultHTTPClient()
defaultCfg := DefaultHTTPClientConfig()
defaultCfg.RootCAs = caPool
defaultCfg.InsecureSkipVerify = config.InsecureSkipVerify
httpClient = CreatePooledHTTPClient(defaultCfg)
}
tokenCfg := TokenHTTPClientConfig()
tokenCfg.RootCAs = caPool
tokenCfg.InsecureSkipVerify = config.InsecureSkipVerify
tokenHTTPClient := CreatePooledHTTPClient(tokenCfg)
goroutineWG := &sync.WaitGroup{}
cacheManager := GetGlobalCacheManagerWithConfig(goroutineWG, config)
@@ -199,7 +213,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: cacheManager.GetSharedTokenCache(),
httpClient: httpClient,
tokenHTTPClient: CreateTokenHTTPClient(),
tokenHTTPClient: tokenHTTPClient,
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
@@ -212,16 +226,30 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return 60 * time.Second
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
maxRefreshTokenAge: func() time.Duration {
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
if config.MaxRefreshTokenAgeSeconds > 0 {
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
}
return 0
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
stripAuthCookies: config.StripAuthCookies,
enableBackchannelLogout: config.EnableBackchannelLogout,
enableFrontchannelLogout: config.EnableFrontchannelLogout,
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
}
// Log audience configuration
@@ -240,6 +268,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
tokenResilienceConfig := DefaultTokenResilienceConfig()
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
// call, preventing the thundering herd that yields invalid_grant when the IdP
// rotates refresh tokens (Zitadel/Authentik default).
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
t.extractClaimsFunc = extractClaims
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
@@ -287,17 +320,22 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
startReplayCacheCleanup(pluginCtx, logger)
// Start memory monitoring for leak detection and performance insights
// Start memory monitoring for leak detection and performance insights.
// The interval is clamped to MinMemoryMonitorInterval (30s) inside
// StartMonitoring; tests that need deterministic sampling should call
// MemoryMonitor.Refresh() directly instead of waiting on a fast ticker.
memoryMonitor := GetGlobalMemoryMonitor()
monitorInterval := 60 * time.Second
if isTestMode() {
monitorInterval = 100 * time.Millisecond // Fast interval for tests
}
memoryMonitor.StartMonitoring(pluginCtx, monitorInterval)
memoryMonitor.StartMonitoring(pluginCtx, DefaultMemoryMonitorInterval)
logger.Debug("Started global memory monitoring")
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
// Log callback URL configuration to help diagnose redirect loop issues.
// If callbackURL is a full URL instead of a path, the callback matching
// in ServeHTTP will silently fail because req.URL.Path is compared directly.
logger.Debugf("TraefikOidc.New: callbackURL (redirURLPath) configured as: %q", t.redirURLPath)
logger.Debugf("TraefikOidc.New: logoutURLPath configured as: %q", t.logoutURLPath)
t.providerURL = config.ProviderURL
// Use singleton resource manager for metadata initialization
@@ -433,6 +471,19 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
t.dcrConfig,
t.providerURL,
)
// Set up storage backend for credentials persistence
if t.dcrConfig.PersistCredentials {
cacheManager := GetGlobalCacheManagerWithConfig(t.goroutineWG, nil)
store, err := NewDCRCredentialsStore(t.dcrConfig, cacheManager, t.logger)
if err != nil {
t.logger.Errorf("Failed to create DCR credentials store: %v", err)
// Continue without persistence - registration will still work
} else {
t.dynamicClientRegistrar.SetStore(store)
t.logger.Debugf("DCR credentials store initialized with backend: %s", t.dcrConfig.StorageBackend)
}
}
}
// Get registration endpoint (from metadata or config override)
+1 -1
View File
@@ -9,7 +9,7 @@ import (
"gopkg.in/yaml.v3"
)
// Config Marshalling Tests
// Config Marshaling Tests
func TestConfig_MarshalJSON(t *testing.T) {
config := &Config{
+535 -42
View File
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
}
}
// TestServeHTTP_EventStream tests the event-stream bypass functionality
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
// handshake must skip the OIDC redirect dance (clients can't follow it
// mid-stream) but it must STILL require an authenticated session, otherwise
// any caller could reach the backend by setting Accept: text/event-stream.
func TestServeHTTP_EventStream(t *testing.T) {
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
}
})
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
// Build an authenticated session and inject its cookies onto req.
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated SSE request to be forwarded to backend")
}
if forwardedUser != "user@example.com" {
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
}
})
}
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
// handshake bypasses the OIDC redirect (clients can't follow it) but the
// session must already be authenticated, otherwise the backend is exposed
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
}))
oidc.ServeHTTP(rw, req)
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
rw := httptest.NewRecorder()
if !nextCalled {
t.Error("expected event-stream request to bypass OIDC")
}
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
}
})
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
}))
req := httptest.NewRequest("GET", "/ws", nil)
// Mixed-case + multi-token Connection header to exercise parsing.
req.Header.Set("Connection", "keep-alive, Upgrade")
req.Header.Set("Upgrade", "WebSocket")
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("ws-user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
}
if forwardedUser != "ws-user@example.com" {
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
}
})
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
// Sanity: requests without Upgrade headers must NOT hit the WS
// bypass branch (otherwise the new code path could short-circuit
// normal authentication).
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("backend must not be called for unauthenticated plain HTTP")
}))
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "keep-alive")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code == http.StatusOK {
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
}
})
}
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "successful authorization with email",
setupSession: func() *MockSessionData {
session := &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "no email triggers reauth",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "",
userIdentifier: "",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "roles and groups authorization",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "unauthorized role/group returns 403",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "template headers processing",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "OPTIONS request with CORS",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
manager: &SessionManager{logger: NewLogger("debug")},
}
// Copy values from mock to concrete session
concreteSession.SetEmail(session.email)
concreteSession.SetUserIdentifier(session.userIdentifier)
concreteSession.SetIDToken(session.idToken)
concreteSession.SetAccessToken(session.accessToken)
concreteSession.SetRefreshToken(session.refreshToken)
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
// MockSessionData is a test implementation of SessionData interface
type MockSessionData struct {
email string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
userIdentifier string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
}
func (m *MockSessionData) GetEmail() string { return m.email }
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
func (m *MockSessionData) GetIDToken() string { return m.idToken }
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
func (m *MockSessionData) SetEmail(email string) { m.email = email }
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
}
// Set up session data
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Call processAuthorizedRequest directly
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
@@ -710,3 +862,344 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
}
}
// TestStripAuthCookies tests the stripAuthCookies configuration option.
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
func TestStripAuthCookies(t *testing.T) {
tests := []struct {
name string
stripAuthCookies bool
expectOIDCCookies bool
expectAppCookies bool
}{
{
name: "stripAuthCookies=false (default) forwards all cookies",
stripAuthCookies: false,
expectOIDCCookies: true,
expectAppCookies: true,
},
{
name: "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
stripAuthCookies: true,
expectOIDCCookies: false,
expectAppCookies: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: tt.stripAuthCookies,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
// Get a valid session first (before adding fake cookies)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Now add OIDC session cookies (simulating what the browser would send)
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
// Add non-OIDC application cookies (these must always pass through)
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Check for OIDC cookies in captured cookies
hasOIDCCookie := false
hasAppSession := false
hasTheme := false
for _, c := range capturedCookies {
if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
hasOIDCCookie = true
}
if c.Name == "my_app_session" {
hasAppSession = true
}
if c.Name == "theme" {
hasTheme = true
}
}
if tt.expectOIDCCookies && !hasOIDCCookie {
t.Error("expected OIDC cookies to be forwarded to backend")
}
if !tt.expectOIDCCookies && hasOIDCCookie {
t.Error("expected OIDC cookies to be stripped before forwarding to backend")
}
if tt.expectAppCookies && !hasAppSession {
t.Error("expected my_app_session cookie to be forwarded to backend")
}
if tt.expectAppCookies && !hasTheme {
t.Error("expected theme cookie to be forwarded to backend")
}
})
}
}
// TestStripAuthCookies_NoCookies verifies stripping works when the request has no cookies.
func TestStripAuthCookies_NoCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected no cookies, got %d", len(capturedCookies))
}
}
// TestStripAuthCookies_OnlyOIDCCookies verifies that when all cookies are OIDC cookies,
// the Cookie header is empty after stripping.
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
var capturedCookieHeader string
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookieHeader = r.Header.Get("Cookie")
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only OIDC cookies
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
}
if capturedCookieHeader != "" {
t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
}
}
// TestStripAuthCookies_OnlyAppCookies verifies that non-OIDC cookies pass through
// untouched when stripping is enabled.
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only non-OIDC cookies
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "abc123"})
req.AddCookie(&http.Cookie{Name: "lang", Value: "en"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 3 {
t.Errorf("expected 3 cookies, got %d", len(capturedCookies))
}
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
for _, expected := range []string{"my_app_session", "lang", "theme"} {
if !cookieNames[expected] {
t.Errorf("expected cookie %q to be forwarded", expected)
}
}
}
// TestStripAuthCookies_CustomPrefix verifies stripping works with a custom cookie prefix.
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
// Create session manager with custom prefix
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
customPrefix := sm.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add cookies with the custom prefix (should be stripped)
req.AddCookie(&http.Cookie{Name: customPrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: customPrefix + "s_0", Value: "chunk0"})
// Add default-prefix cookie (should NOT be stripped — different prefix)
req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
// Add app cookie (should NOT be stripped)
req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
// Custom prefix cookies should be stripped
if cookieNames[customPrefix+"m"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"m")
}
if cookieNames[customPrefix+"s_0"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"s_0")
}
// Default prefix cookie should pass through (different prefix)
if !cookieNames["_oidc_raczylo_m"] {
t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
}
// App cookie should pass through
if !cookieNames["my_app"] {
t.Error("expected my_app cookie to pass through")
}
}
+41 -15
View File
@@ -208,6 +208,32 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *
return m.JWKS, m.Err
}
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.Err != nil {
return nil, m.Err
}
if m.JWKS == nil {
return nil, fmt.Errorf("JWKS is nil")
}
for i := range m.JWKS.Keys {
k := &m.JWKS.Keys[i]
if k.Kid != kid {
continue
}
switch k.Kty {
case "RSA":
return k.ToRSAPublicKey()
case "EC":
return k.ToECDSAPublicKey()
default:
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
}
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
func (m *MockJWKCache) Cleanup() {
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
// Real cleanup is for expired entries, not resetting all data
@@ -554,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case to avoid replay issues
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -577,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -634,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -652,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -680,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -715,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
@@ -746,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(validToken)
session.SetIDToken(validToken) // Ensure ID token is also set
session.SetRefreshToken("should-not-be-used-refresh-token")
@@ -766,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -788,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -2153,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
},
expectedPath: "/original/path",
},
@@ -2730,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2756,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2783,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
@@ -2803,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2825,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
+24 -15
View File
@@ -9,13 +9,18 @@ import (
// LazyBackgroundTask wraps BackgroundTask to provide delayed initialization.
// This prevents memory leaks from unnecessary background tasks by starting
// them only when actually needed, reducing resource usage in idle scenarios.
//
// Lifecycle is one-shot: once Stop has been called the task cannot be
// restarted. The underlying BackgroundTask uses sync.Once for Start and
// refuses to re-run after Stop, so restart is not supported by design.
type LazyBackgroundTask struct {
// BackgroundTask is the underlying task implementation
*BackgroundTask
// started tracks whether the task has been activated
// mu guards the started flag against concurrent StartIfNeeded / Stop calls.
mu sync.Mutex
// started tracks whether the task has been activated.
// Only mutated while holding mu.
started bool
// startOnce ensures single initialization
startOnce sync.Once
}
// NewLazyBackgroundTask creates a background task that doesn't start immediately.
@@ -29,24 +34,28 @@ func NewLazyBackgroundTask(name string, interval time.Duration, taskFunc func(),
}
// StartIfNeeded starts the background task only if it hasn't been started yet.
// Uses sync.Once to ensure thread-safe single initialization.
// Safe to call concurrently. After Stop has been called this is a no-op;
// the task is not restartable.
func (lt *LazyBackgroundTask) StartIfNeeded() {
lt.startOnce.Do(func() {
if !lt.started {
lt.BackgroundTask.Start()
lt.started = true
}
})
lt.mu.Lock()
defer lt.mu.Unlock()
if lt.started {
return
}
lt.BackgroundTask.Start()
lt.started = true
}
// Stop stops the background task if it was started.
// Resets the start state to allow potential future re-initialization.
// Once stopped, the task cannot be restarted (see type doc).
func (lt *LazyBackgroundTask) Stop() {
if lt.started {
lt.BackgroundTask.Stop()
lt.started = false
lt.startOnce = sync.Once{}
lt.mu.Lock()
defer lt.mu.Unlock()
if !lt.started {
return
}
lt.BackgroundTask.Stop()
lt.started = false
}
// NewLazyCacheWithLogger creates a cache that doesn't start cleanup until first use.
+142 -12
View File
@@ -58,13 +58,21 @@ func (mpl MemoryPressureLevel) String() string {
}
}
// MemoryMonitor provides comprehensive memory monitoring and alerting
// MemoryMonitor provides comprehensive memory monitoring and alerting.
//
// Memory sampling is expensive: runtime.ReadMemStats is a stop-the-world
// operation. To keep latency predictable the monitor caches the most recent
// sample and only refreshes it when the background ticker fires, when TriggerGC
// is invoked, or when a caller explicitly calls Refresh(). GetCurrentStats is a
// cheap read of that cached sample.
type MemoryMonitor struct {
lastGCTime time.Time
startTime time.Time
lastStats *MemoryStats
cachedMemStats runtime.MemStats
logger *Logger
alertThresholds MemoryAlertThresholds
config MemoryMonitorConfig
baselineGoroutines int
baselineHeap uint64
heapGrowthRate float64
@@ -84,6 +92,30 @@ type MemoryAlertThresholds struct {
GCFrequency float64 // Alert when GC frequency exceeds this per minute
}
// MemoryMonitorConfig configures the memory monitor's scheduling behavior.
// Thresholds are kept separate in MemoryAlertThresholds.
type MemoryMonitorConfig struct {
// Interval between background samples. Must be >= MinMemoryMonitorInterval
// (30s). Values below the minimum are clamped when monitoring starts.
Interval time.Duration
}
// Default and minimum interval values. The minimum exists because
// runtime.ReadMemStats is stop-the-world and hammering it on a hot loop causes
// noticeable latency spikes, especially under yaegi.
const (
DefaultMemoryMonitorInterval = 60 * time.Second
MinMemoryMonitorInterval = 30 * time.Second
)
// DefaultMemoryMonitorConfig returns a config with sensible production
// defaults.
func DefaultMemoryMonitorConfig() MemoryMonitorConfig {
return MemoryMonitorConfig{
Interval: DefaultMemoryMonitorInterval,
}
}
// DefaultMemoryAlertThresholds returns sensible default alert thresholds
func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
return MemoryAlertThresholds{
@@ -95,35 +127,82 @@ func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
}
}
// NewMemoryMonitor creates a new memory monitor
// NewMemoryMonitor creates a new memory monitor using default scheduling
// configuration. See NewMemoryMonitorWithConfig for full control.
func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryMonitor {
return NewMemoryMonitorWithConfig(logger, thresholds, DefaultMemoryMonitorConfig())
}
// NewMemoryMonitorWithConfig creates a new memory monitor with an explicit
// scheduling config.
//
// NOTE: the constructor performs a single runtime.ReadMemStats call to capture
// baseline heap / goroutine / GC counters used for leak and growth detection.
// This is a one-time stop-the-world cost at startup; all subsequent samples
// only happen on the monitoring ticker or on explicit Refresh() calls.
func NewMemoryMonitorWithConfig(logger *Logger, thresholds MemoryAlertThresholds, config MemoryMonitorConfig) *MemoryMonitor {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
if config.Interval <= 0 {
config.Interval = DefaultMemoryMonitorInterval
}
// One-time initial sample to seed baselines used for growth / leak
// detection. All subsequent sampling is gated by the monitoring ticker or
// explicit Refresh() calls.
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
return &MemoryMonitor{
mm := &MemoryMonitor{
logger: logger,
startTime: time.Now(),
alertThresholds: thresholds,
config: config,
baselineHeap: memStats.HeapAlloc,
baselineGoroutines: runtime.NumGoroutine(),
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
lastGCCount: memStats.NumGC,
}
mm.cachedMemStats = memStats
return mm
}
// GetCurrentStats collects current memory statistics
// GetCurrentStats returns the most recently sampled memory statistics.
//
// This is a cheap cached read: it does NOT call runtime.ReadMemStats. Samples
// are refreshed only by the monitoring ticker or by an explicit call to
// Refresh(). If no sample has been produced yet, stats derived from the
// constructor-time raw sample are returned (with no additional STW cost).
func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
mm.mu.RLock()
stats := mm.lastStats
mm.mu.RUnlock()
if stats != nil {
return stats
}
return mm.buildStatsFromCache()
}
// Refresh synchronously samples current memory statistics via
// runtime.ReadMemStats and updates the cached value. This is the only path
// (other than the monitoring ticker and TriggerGC) that pays the stop-the-world
// cost. Use it in tests or in callers that explicitly need a fresh sample.
func (mm *MemoryMonitor) Refresh() *MemoryStats {
return mm.sample()
}
// sample performs a stop-the-world ReadMemStats, updates the cached raw stats,
// computes a derived MemoryStats snapshot, and stores it as lastStats.
func (mm *MemoryMonitor) sample() *MemoryStats {
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
now := time.Now()
// Calculate GC frequency
// Calculate GC frequency relative to the previous snapshot.
gcFrequency := 0.0
mm.mu.RLock()
lastStats := mm.lastStats
@@ -168,6 +247,7 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
mm.updateHeapGrowthTracking(stats)
mm.mu.Lock()
mm.cachedMemStats = memStats
mm.lastStats = stats
mm.lastGCCount = memStats.NumGC
mm.mu.Unlock()
@@ -175,6 +255,35 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
return stats
}
// buildStatsFromCache constructs a MemoryStats snapshot from the cached raw
// runtime.MemStats without issuing a new ReadMemStats call. Used as a fallback
// when GetCurrentStats is called before the first sample() has completed.
func (mm *MemoryMonitor) buildStatsFromCache() *MemoryStats {
mm.mu.RLock()
memStats := mm.cachedMemStats
mm.mu.RUnlock()
stats := &MemoryStats{
HeapAllocBytes: memStats.HeapAlloc,
HeapSysBytes: memStats.HeapSys,
HeapIdleBytes: memStats.HeapIdle,
HeapInuseBytes: memStats.HeapInuse,
HeapReleasedBytes: memStats.HeapReleased,
HeapObjects: memStats.HeapObjects,
StackInuseBytes: memStats.StackInuse,
StackSysBytes: memStats.StackSys,
GCSysBytes: memStats.GCSys,
NumGoroutines: runtime.NumGoroutine(),
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
GCFrequency: 0.0,
Timestamp: time.Now(),
}
mm.collectApplicationStats(stats)
stats.MemoryPressure = mm.calculateMemoryPressure(stats)
return stats
}
// collectApplicationStats gathers application-specific memory stats
func (mm *MemoryMonitor) collectApplicationStats(stats *MemoryStats) {
// Get session count from ChunkManager if available
@@ -229,7 +338,7 @@ func (mm *MemoryMonitor) updateGoroutineTracking(stats *MemoryStats) {
}
// Check for potential goroutine leak
if stats.NumGoroutines > mm.baselineGoroutines+int(mm.alertThresholds.GoroutineCount) {
if stats.NumGoroutines > mm.baselineGoroutines+mm.alertThresholds.GoroutineCount {
mm.mu.Lock()
wasAlert := mm.goroutineLeakAlert
if !wasAlert {
@@ -302,7 +411,16 @@ var (
globalMonitoringMutex sync.Mutex
)
// StartMonitoring starts continuous memory monitoring as a global singleton
// StartMonitoring starts continuous memory monitoring as a global singleton.
//
// The effective interval is resolved as follows:
// 1. If the caller passes a positive interval, that is used.
// 2. Otherwise the configured MemoryMonitorConfig.Interval is used.
// 3. Otherwise the built-in default (60s) is used.
//
// The result is then clamped to a minimum of MinMemoryMonitorInterval (30s) to
// avoid stop-the-world ReadMemStats storms. Callers that need rapid updates in
// tests should call Refresh() directly instead of spinning the ticker fast.
func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Duration) {
globalMonitoringMutex.Lock()
defer globalMonitoringMutex.Unlock()
@@ -316,7 +434,17 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
}
if interval <= 0 {
interval = 30 * time.Second
interval = mm.config.Interval
}
if interval <= 0 {
interval = DefaultMemoryMonitorInterval
}
if interval < MinMemoryMonitorInterval {
if !isTestMode() {
mm.logger.Debug("Memory monitor interval %v is below minimum %v; clamping",
interval, MinMemoryMonitorInterval)
}
interval = MinMemoryMonitorInterval
}
registry := GetGlobalTaskRegistry()
@@ -325,7 +453,7 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
"memory-monitor",
interval,
func() {
stats := mm.GetCurrentStats()
stats := mm.sample()
mm.LogMemoryStats(stats)
mm.checkAlerts(stats)
},
@@ -369,14 +497,16 @@ func (mm *MemoryMonitor) checkAlerts(stats *MemoryStats) {
}
}
// TriggerGC forces garbage collection and logs the impact
// TriggerGC forces garbage collection and logs the impact. Both the before and
// after measurements are fresh samples (explicit Refresh() calls) because the
// comparison is meaningless against a stale cached snapshot.
func (mm *MemoryMonitor) TriggerGC() {
before := mm.GetCurrentStats()
before := mm.Refresh()
runtime.GC()
runtime.GC() // Run twice to ensure full collection
after := mm.GetCurrentStats()
after := mm.Refresh()
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
+273 -60
View File
@@ -13,6 +13,99 @@ import (
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// bypassReason describes why a request is being forwarded without OIDC auth.
// It is only used for logging and to decide whether extra side-effects
// (propagating the user header from an existing session) should run.
const (
bypassReasonExcluded = "excluded-url"
bypassReasonSSE = "sse"
bypassReasonWebSocket = "websocket"
)
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
// (RFC 6455). The middleware can only see the handshake; once Traefik
// completes the upgrade it forwards frames directly, so we never re-process
// per-frame traffic. We bypass auth on the handshake the same way we do for
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
func isWebSocketUpgrade(req *http.Request) bool {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
return false
}
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
return true
}
}
return false
}
// shouldBypassAuth decides whether a request must skip OIDC authentication
// entirely. It returns (true, reason) when either the request path matches a
// configured excluded URL, the Accept header asks for a text/event-stream
// response (SSE), or the request is a WebSocket upgrade handshake. The
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
// kind (e.g. propagating user headers).
//
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
// WebSocket traffic is never blocked by a slow/broken provider.
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
if t.determineExcludedURL(req.URL.Path) {
return true, bypassReasonExcluded
}
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
return true, bypassReasonSSE
}
if isWebSocketUpgrade(req) {
return true, bypassReasonWebSocket
}
return false, ""
}
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
// requests and, on success, copies the authenticated user's identity onto
// the outgoing request so downstream services can see who the user is.
//
// Returns true when the request carries a valid authenticated session and
// the bypass should proceed. Returns false when no usable session is
// present; callers must then reject the request (typically with 401) to
// prevent unauthenticated traffic from reaching the backend just by setting
// `Accept: text/event-stream` or sending a WebSocket upgrade.
//
// The check is cookie-only: the session cookie is sealed by our encryption
// key, so the authenticated flag cannot be forged. We do NOT run full token
// signature verification here so that SSE/WS keeps working when the OIDC
// provider is briefly unavailable for JWK fetches.
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
if t.sessionManager == nil {
return false
}
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
return false
}
defer session.returnToPoolSafely()
if !session.GetAuthenticated() {
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
return false
}
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
return false
}
req.Header.Set("X-Forwarded-User", userIdentifier)
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-User", userIdentifier)
}
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
return true
}
// ServeHTTP implements the main middleware logic for processing HTTP requests.
// It handles the complete OIDC authentication flow including:
// - Excluded URL bypass
@@ -26,6 +119,31 @@ import (
// - rw: The HTTP response writer.
// - req: The incoming HTTP request.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Log request entry for debugging routing issues
t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path)
// Handle logout requests early - before waiting for OIDC initialization
// This allows users to logout even if the OIDC provider is unavailable
if req.URL.Path == t.logoutURLPath {
t.logger.Debugf("Logout path matched early: %s", req.URL.Path)
t.handleLogout(rw, req)
return
}
// Handle backchannel logout (IdP-initiated POST with logout_token)
if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath {
t.logger.Debug("Backchannel logout path matched")
t.handleBackchannelLogout(rw, req)
return
}
// Handle front-channel logout (IdP-initiated GET with sid/iss in iframe)
if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath {
t.logger.Debug("Front-channel logout path matched")
t.handleFrontchannelLogout(rw, req)
return
}
if !strings.HasPrefix(req.URL.Path, "/health") {
t.firstRequestMutex.Lock()
if !t.firstRequestReceived {
@@ -42,6 +160,43 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.firstRequestMutex.Unlock()
}
// Evaluate auth-bypass once, before waiting for initialization. Excluded
// URLs, SSE and WebSocket upgrade requests must not block on provider
// init. For SSE/WebSocket we ALSO require an authenticated session
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
// of in-flight streams can't follow an OIDC redirect, so forwarding
// unauthenticated traffic would silently expose the backend.
if bypass, reason := t.shouldBypassAuth(req); bypass {
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
switch reason {
case bypassReasonExcluded:
// Operator-declared excluded URLs forward unconditionally.
t.next.ServeHTTP(rw, req)
case bypassReasonSSE, bypassReasonWebSocket:
// Skip the OIDC redirect dance (clients can't follow it
// mid-stream) but still require an authenticated session.
// Otherwise an unauthenticated client could hit the backend
// just by setting Accept: text/event-stream or sending a
// WebSocket upgrade.
if !t.applyBypassUserHeaders(req, reason) {
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
return
}
t.next.ServeHTTP(rw, req)
default:
t.next.ServeHTTP(rw, req)
}
return
}
// Log waiting for initialization to help diagnose hanging requests
t.logger.Debug("Waiting for OIDC provider initialization...")
// time.NewTimer + Stop avoids leaking a goroutine+channel for 30s on every
// request when initComplete fires quickly (would happen with time.After).
initTimer := time.NewTimer(30 * time.Second)
defer initTimer.Stop()
select {
case <-t.initComplete:
// Read issuerURL with RLock
@@ -72,24 +227,13 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Request canceled while waiting for OIDC initialization")
t.sendErrorResponse(rw, req, "Request canceled", http.StatusRequestTimeout)
return
case <-time.After(30 * time.Second):
case <-initTimer.C:
t.logger.Error("Timeout waiting for OIDC initialization")
t.sendErrorResponse(rw, req, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
return
}
if t.determineExcludedURL(req.URL.Path) {
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
t.next.ServeHTTP(rw, req)
return
}
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
t.next.ServeHTTP(rw, req)
return
}
// Bypass checks already ran before the init wait; no need to repeat them.
t.sessionManager.CleanupOldCookies(rw, req)
session, err := t.sessionManager.GetSession(req)
@@ -107,6 +251,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.sendErrorResponse(rw, req, "Critical session error", http.StatusInternalServerError)
return
}
// Sub-resource requests (script/image/fetch/serviceWorker) must not
// trigger an OIDC redirect from this path either: they would overwrite
// any in-flight CSRF/nonce in the session. Let the next HTML navigation
// initiate the flow. See issue #129.
if t.isAjaxRequest(req) || t.isNonNavigationRequest(req) {
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
return
}
scheme := utils.DetermineScheme(req, t.forceHTTPS)
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
@@ -120,14 +272,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Check if the current request is the OIDC callback
t.logger.Debugf("Checking callback URL match: request_path=%q, configured_callback=%q", req.URL.Path, t.redirURLPath)
if req.URL.Path == t.redirURLPath {
t.logger.Debugf("Callback URL matched, processing OIDC callback (redirect_url=%s)", redirectURL)
t.handleCallback(rw, req, redirectURL)
return
}
t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath)
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
@@ -137,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
userIdentifier := session.GetUserIdentifier()
// User authorization check
if authenticated && userIdentifier != "" {
if !t.isAllowedUser(userIdentifier) {
@@ -160,8 +312,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshTokenPresent := session.GetRefreshToken() != ""
// Check if this is an AJAX request that should receive 401 instead of redirect
isAjaxRequest := t.isAjaxRequest(req)
// Decide whether to answer with 401 instead of a redirect. AJAX requests
// cannot follow a 302 into an IdP, and sub-resource loads (script/image/
// fetch/serviceWorker) must not trigger a fresh OIDC flow because parallel
// loads would each overwrite the session CSRF/nonce (issue #129). Only
// top-level HTML navigations should redirect.
isAjaxRequest := t.isAjaxRequest(req) || t.isNonNavigationRequest(req)
// Check if refresh token is likely expired (older than 6 hours)
refreshTokenExpired := refreshTokenPresent && t.isRefreshTokenExpired(session)
@@ -205,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshed := t.refreshToken(rw, req, session)
if refreshed {
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
userIdentifier = session.GetUserIdentifier()
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
@@ -255,40 +411,79 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - session: The user's session data containing tokens and claims.
// - redirectURL: The callback URL for re-authentication if needed.
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
t.logger.Info("No email found in session during final processing, initiating re-auth")
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
// Reset redirect count to prevent loops when session is invalid
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
tokenForClaims = session.GetAccessToken()
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
// Reset redirect count to prevent loops when token is missing
// Check if session has been invalidated via backchannel or front-channel logout
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
idToken := session.GetIDToken()
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
// Clear the session and redirect to login
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
}
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
}
// Resolve ID-token claims at most once per request. SessionData caches
// the parsed claims keyed on the raw ID token, so concurrent dashboard
// panel requests on the same session don't repeatedly base64-decode and
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
// that hosts Traefik plugins). idClaims is reused below by the
// header-templates branch.
idToken := session.GetIDToken()
var (
idClaims map[string]interface{}
idClaimsErr error
)
if idToken != "" {
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
}
// Choose which claims drive groups/roles extraction. Prefer the ID
// token (cached) and fall back to the access token if there is no ID
// token in the session — matching the prior behavior for opaque
// ID-token providers.
var (
groupClaims map[string]interface{}
groupClaimsErr error
)
if idToken != "" {
groupClaims, groupClaimsErr = idClaims, idClaimsErr
} else if accessToken := session.GetAccessToken(); accessToken != "" {
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
} else if len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
var groups, roles []string
if groupClaimsErr == nil && groupClaims != nil {
var err error
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
// Initialize empty slices
var groups, roles []string
if tokenForClaims != "" {
var err error
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Reset redirect count to prevent loops when claim extraction fails
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
} else if err == nil {
if err == nil {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -307,51 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
req.Header.Set("X-Forwarded-User", userIdentifier)
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-User", userIdentifier)
if idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
}
if len(t.headerTemplates) > 0 {
claims, err := t.extractClaimsFunc(session.GetIDToken())
if err != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
if idClaimsErr != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
} else {
// idClaims may be nil when no ID token is present; templates
// referencing .Claims.* will simply produce empty values, which
// matches the prior behavior.
templateData := map[string]interface{}{
"AccessToken": session.GetAccessToken(),
"IDToken": session.GetIDToken(),
"IDToken": idToken,
"RefreshToken": session.GetRefreshToken(),
"Claims": claims,
"Claims": idClaims,
}
for headerName, tmpl := range t.headerTemplates {
var buf bytes.Buffer
if err := tmpl.Execute(&buf, templateData); err != nil {
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
continue
}
headerValue := buf.String()
req.Header.Set(headerName, headerValue)
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
}
session.MarkDirty()
t.logger.Debugf("Session marked dirty after templated header processing.")
// NOTE: templates only mutate request headers (not session state),
// so we deliberately do NOT MarkDirty / Save here. Previously every
// authenticated request with header templates re-encrypted and
// rewrote all session cookies, which was a measurable CPU and
// Set-Cookie tax on dashboards that poll many panels per second.
}
}
@@ -374,7 +571,23 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
}
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
// Strip OIDC session cookies before forwarding to the backend to prevent
// HTTP 431 "Request Header Fields Too Large" errors (GitHub issue #122).
if t.stripAuthCookies {
prefix := t.sessionManager.GetCookiePrefix()
filtered := make([]*http.Cookie, 0, len(req.Cookies()))
for _, c := range req.Cookies() {
if !strings.HasPrefix(c.Name, prefix) {
filtered = append(filtered, c)
}
}
req.Header.Del("Cookie")
for _, c := range filtered {
req.AddCookie(c)
}
}
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
t.next.ServeHTTP(rw, req)
}
+39 -7
View File
@@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
}
}
// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable
// This is critical for allowing users to clear their session when the provider is down
func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
logoutURLPath: "/logout",
postLogoutRedirectURI: "/",
forceHTTPS: false,
}
// Note: initComplete is NOT closed, simulating OIDC provider being unavailable
req := httptest.NewRequest("GET", "/logout", nil)
req.Host = "example.com"
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Should redirect to post-logout URI even without OIDC initialization
if rw.Code != http.StatusFound {
t.Errorf("Expected redirect (302) for logout, got %d", rw.Code)
}
location := rw.Header().Get("Location")
if location == "" {
t.Error("Expected Location header for logout redirect")
}
}
// TestMiddlewareDomainRestrictions tests domain-based access control
// NOTE: Currently commented out due to complex session setup requirements
// These scenarios are tested indirectly through integration tests
@@ -129,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create authenticated session
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
session.SetIDToken("dummy-token")
session.Save(req, httptest.NewRecorder())
@@ -171,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create session with forbidden domain
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@forbidden.com")
session.SetUserIdentifier("user@forbidden.com")
session.SetAuthenticated(true)
// Save and inject cookies
@@ -220,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
// Create session with opaque token
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
session.SetAuthenticated(true)
@@ -259,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("") // No email
session.SetUserIdentifier("") // No email
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -289,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("") // No ID token
session.SetAccessToken("") // No access token
@@ -317,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -351,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
testEmail := "user@example.com"
session.SetEmail(testEmail)
session.SetUserIdentifier(testEmail)
session.SetIDToken("dummy-id-token")
rw := httptest.NewRecorder()
+34 -44
View File
@@ -18,7 +18,6 @@ type RefreshCoordinator struct {
inFlightRefreshes map[string]*refreshOperation
cleanupTimers map[string]*time.Timer
sessionRefreshAttempts map[string]*refreshAttemptTracker
delayedCleanupQueue chan delayedCleanupItem
circuitBreaker *RefreshCircuitBreaker
metrics *RefreshMetrics
logger *Logger
@@ -107,12 +106,6 @@ type RefreshMetrics struct {
currentInFlightRefreshes int32
}
// delayedCleanupItem represents an item scheduled for delayed cleanup
type delayedCleanupItem struct {
cleanupAt time.Time
tokenHash string
}
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
type RefreshCircuitBreaker struct {
lastFailureTime time.Time
@@ -143,7 +136,6 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
metrics: &RefreshMetrics{},
logger: logger,
stopChan: make(chan struct{}),
delayedCleanupQueue: make(chan delayedCleanupItem, 1000), // Buffered channel for cleanup items
cleanupTimers: make(map[string]*time.Timer),
circuitBreaker: &RefreshCircuitBreaker{
config: RefreshCircuitBreakerConfig{
@@ -158,10 +150,6 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
rc.wg.Add(1)
go rc.cleanupRoutine()
// Start delayed cleanup processor (single goroutine processes all cleanup timers)
rc.wg.Add(1)
go rc.processDelayedCleanups()
return rc
}
@@ -234,7 +222,7 @@ func (rc *RefreshCoordinator) CoordinateRefresh(
// Returns (operation, false, nil) if joined an existing operation
// Returns (nil, false, error) if the operation was rejected
func (rc *RefreshCoordinator) getOrCreateOperation(
ctx context.Context,
_ context.Context,
sessionID string,
tokenHash string,
refreshToken string,
@@ -293,7 +281,7 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
// executeRefreshAsync performs the actual refresh operation asynchronously
func (rc *RefreshCoordinator) executeRefreshAsync(
operation *refreshOperation,
sessionID string,
_ string, // sessionID - reserved for future metrics/logging
tokenHash string,
refreshFunc func() (*TokenResponse, error),
) {
@@ -377,35 +365,19 @@ func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
rc.cleanupTimerMu.Unlock()
}
// performCleanup removes the operation from the in-flight map
// performCleanup removes the operation from the in-flight map.
// Idempotent: only decrements the in-flight counter if an entry was actually
// removed. This guards against any future path accidentally calling cleanup
// twice for the same tokenHash (which would corrupt the refresh budget).
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
rc.refreshMutex.Lock()
delete(rc.inFlightRefreshes, tokenHash)
_, existed := rc.inFlightRefreshes[tokenHash]
if existed {
delete(rc.inFlightRefreshes, tokenHash)
}
rc.refreshMutex.Unlock()
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
}
// processDelayedCleanups processes delayed cleanup requests from the queue
// This is a single goroutine that handles all delayed cleanups
func (rc *RefreshCoordinator) processDelayedCleanups() {
defer rc.wg.Done()
for {
select {
case item := <-rc.delayedCleanupQueue:
// Wait until cleanup time
waitDuration := time.Until(item.cleanupAt)
if waitDuration > 0 {
select {
case <-time.After(waitDuration):
case <-rc.stopChan:
return
}
}
rc.performCleanup(item.tokenHash)
case <-rc.stopChan:
return
}
if existed {
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
}
}
@@ -494,15 +466,33 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
// hashRefreshToken creates a hash of the refresh token for deduplication
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
return refreshCoordinatorSessionID(token)
}
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
// for both deduplication and per-session attempt tracking. Using sha256 of the
// raw token means each rotation produces a fresh sessionID with its own attempt
// budget, which is what we want.
func refreshCoordinatorSessionID(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// isUnderMemoryPressure checks if the system is under memory pressure
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
// coordinated refresh result. It is wider than RefreshTimeout so a follower
// always sees the leader's result instead of timing out independently.
const refreshCoordinatorWaitTimeout = 35 * time.Second
// isUnderMemoryPressure checks if the system is under memory pressure by
// consulting the global memory monitor. Returns true when pressure reaches
// High or Critical, at which point we refuse new refresh operations to
// avoid aggravating an already-stressed heap.
func (rc *RefreshCoordinator) isUnderMemoryPressure() bool {
// This is a simplified check - in production you'd want to use runtime.MemStats
// or system-specific memory monitoring
return false // Placeholder - implement actual memory check
monitor := GetGlobalMemoryMonitor()
if monitor == nil {
return false
}
return monitor.GetMemoryPressure() >= MemoryPressureHigh
}
// cleanupRoutine periodically cleans up stale tracking entries
+164
View File
@@ -0,0 +1,164 @@
package traefikoidc
import (
"context"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
// stubTokenExchanger lets us count how many upstream refresh-token grants
// happen for a given refresh_token across concurrent middleware-level calls.
type stubTokenExchanger struct {
calls int32
delay time.Duration
resp *TokenResponse
}
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
return nil, nil
}
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
atomic.AddInt32(&s.calls, 1)
if s.delay > 0 {
time.Sleep(s.delay)
}
return s.resp, nil
}
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
return nil
}
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
// concurrent calls to coordinatedTokenRefresh with the same refresh token
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
//
// Without the wireup this assertion fails (one upstream call per goroutine).
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
stub := &stubTokenExchanger{
delay: 100 * time.Millisecond,
resp: &TokenResponse{
AccessToken: "new_access",
RefreshToken: "new_refresh",
IDToken: "new_id",
ExpiresIn: 3600,
},
}
logger := NewLogger("error")
cfg := DefaultRefreshCoordinatorConfig()
cfg.MaxRefreshAttempts = 10000
cfg.MaxConcurrentRefreshes = 32
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
}
defer oidc.refreshCoordinator.Shutdown()
const concurrency = 50
var wg sync.WaitGroup
wg.Add(concurrency)
req := httptest.NewRequest("GET", "/", nil)
start := make(chan struct{})
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
<-start
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if resp == nil || resp.AccessToken != "new_access" {
t.Errorf("unexpected response: %+v", resp)
}
}()
}
close(start)
wg.Wait()
got := atomic.LoadInt32(&stub.calls)
// Up to 2 is acceptable to absorb the documented timing slack in the
// existing coordinator tests (e.g. operation just cleaned up before a
// late goroutine reads the in-flight map). Anything beyond that means
// coalescing is broken.
if got > 2 {
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
}
}
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
// coordinator path so existing tests that build TraefikOidc literals stay
// green.
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
stub := &stubTokenExchanger{
resp: &TokenResponse{AccessToken: "ok"},
}
oidc := &TraefikOidc{
logger: NewLogger("error"),
tokenExchanger: stub,
// refreshCoordinator deliberately nil
}
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.AccessToken != "ok" {
t.Fatalf("unexpected response: %+v", resp)
}
if got := atomic.LoadInt32(&stub.calls); got != 1 {
t.Fatalf("expected exactly 1 upstream call, got %d", got)
}
}
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
// distinct refresh tokens are not falsely coalesced.
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
stub := &stubTokenExchanger{
delay: 20 * time.Millisecond,
resp: &TokenResponse{AccessToken: "ok"},
}
logger := NewLogger("error")
cfg := DefaultRefreshCoordinatorConfig()
cfg.MaxRefreshAttempts = 10000
cfg.MaxConcurrentRefreshes = 32
cfg.DeduplicationCleanupDelay = 0
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
}
defer oidc.refreshCoordinator.Shutdown()
const distinct = 8
var wg sync.WaitGroup
wg.Add(distinct)
for i := 0; i < distinct; i++ {
i := i
go func() {
defer wg.Done()
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}()
}
wg.Wait()
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
}
}
+186
View File
@@ -0,0 +1,186 @@
package traefikoidc
import (
"context"
"errors"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
// replica dedup contract: Set/Get with TTL. Used in place of the universal
// cache singleton so these tests stay hermetic.
type inMemoryCache struct {
entries map[string]inMemoryCacheEntry
mu sync.Mutex
}
type inMemoryCacheEntry struct {
expiresAt time.Time
value interface{}
}
func newInMemoryCache() *inMemoryCache {
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
}
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
}
func (c *inMemoryCache) Get(key string) (any, bool) {
c.mu.Lock()
defer c.mu.Unlock()
e, ok := c.entries[key]
if !ok {
return nil, false
}
if time.Now().After(e.expiresAt) {
delete(c.entries, key)
return nil, false
}
return e.value, true
}
func (c *inMemoryCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.entries, key)
}
func (c *inMemoryCache) SetMaxSize(int) {}
func (c *inMemoryCache) Cleanup() {}
func (c *inMemoryCache) Close() {}
func (c *inMemoryCache) Size() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.entries)
}
func (c *inMemoryCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.entries = map[string]inMemoryCacheEntry{}
}
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
// erroringTokenExchanger always errors - simulates an IdP rejection.
type erroringTokenExchanger struct {
calls int32
}
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
return nil, errors.New("not used")
}
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
atomic.AddInt32(&e.calls, 1)
return nil, errors.New("invalid_grant")
}
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
// replica having just refreshed: the shared cache already has the result, so
// this pod must reuse it without ever calling the IdP.
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
stub := &stubTokenExchanger{
resp: &TokenResponse{AccessToken: "should_not_be_called"},
}
logger := NewLogger("error")
cache := newInMemoryCache()
preExisting := &TokenResponse{
AccessToken: "from_peer",
RefreshToken: "rotated_by_peer",
IDToken: "id_from_peer",
}
rt := "shared_refresh_token"
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
refreshResultCache: cache,
}
defer oidc.refreshCoordinator.Shutdown()
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.AccessToken != "from_peer" {
t.Fatalf("expected peer-provided response, got %+v", resp)
}
if got := atomic.LoadInt32(&stub.calls); got != 0 {
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
}
}
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
// cache miss the leader stores its result for peers to find within the TTL.
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
stub := &stubTokenExchanger{
resp: &TokenResponse{AccessToken: "fresh_grant"},
}
logger := NewLogger("error")
cache := newInMemoryCache()
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
refreshResultCache: cache,
}
defer oidc.refreshCoordinator.Shutdown()
rt := "fresh_refresh_token"
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.AccessToken != "fresh_grant" {
t.Fatalf("unexpected response: %+v", resp)
}
if got := atomic.LoadInt32(&stub.calls); got != 1 {
t.Fatalf("expected 1 upstream call, got %d", got)
}
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
if !ok {
t.Fatal("expected refresh result to be cached after upstream success")
}
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
t.Fatalf("cached value malformed: %+v", v)
}
}
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
// dedup cache when the IdP rejects the grant. Peers must run their own
// refresh; they cannot inherit an error.
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
failing := &erroringTokenExchanger{}
logger := NewLogger("error")
cache := newInMemoryCache()
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: failing,
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
refreshResultCache: cache,
}
defer oidc.refreshCoordinator.Shutdown()
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
t.Fatal("expected an error from the failing exchanger")
}
if cache.Size() != 0 {
t.Fatalf("error result must not be cached, size=%d", cache.Size())
}
}
+68
View File
@@ -0,0 +1,68 @@
package traefikoidc
import (
"testing"
"time"
"github.com/gorilla/sessions"
)
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
// reads from. We can't reuse sessionPool.Get() here because that requires a
// fully initialized SessionManager - overkill for this unit-level check.
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
t.Helper()
rs := sessions.NewSession(nil, "refresh")
if !issuedAt.IsZero() {
rs.Values["issued_at"] = issuedAt.Unix()
}
return &SessionData{
refreshSession: rs,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
idTokenChunks: make(map[int]*sessions.Session),
}
}
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: 0}
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
if tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
}
}
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
if tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
}
}
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
if tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=false within max age")
}
}
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
if !tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
}
}
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
var tr *TraefikOidc
if tr.isRefreshTokenExpired(nil) {
t.Fatal("nil receiver must not panic and must return false")
}
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
if tr.isRefreshTokenExpired(nil) {
t.Fatal("nil session must return false")
}
}
+2 -2
View File
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
// Simulate successful Azure authentication
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Azure may use opaque access tokens
session.SetAccessToken("opaque-azure-access-token")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
require.NoError(t, err)
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
assert.Equal(t, "user@example.com", session2.GetEmail())
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
+2 -2
View File
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
// Set up the attacker's session with malicious data
attackerSession.SetAuthenticated(true)
attackerSession.SetEmail("attacker@evil.com")
attackerSession.SetUserIdentifier("attacker@evil.com")
attackerSession.SetIDToken(ValidIDToken)
attackerSession.SetAccessToken(ValidAccessToken)
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
}
// Get the email from the session
email := session.GetEmail()
email := session.GetUserIdentifier()
w.Header().Set("X-User-Email", email)
w.WriteHeader(http.StatusOK)
})
+90 -42
View File
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
A string `json:"a,omitempty"`
R string `json:"r,omitempty"`
I string `json:"i,omitempty"`
E string `json:"e,omitempty"`
Ui string `json:"ui,omitempty"`
Cs string `json:"cs,omitempty"`
N string `json:"n,omitempty"`
Cv string `json:"cv,omitempty"`
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
// All other mainSession.Values keys are stored in the X (extra) field.
var knownSessionKeys = map[string]bool{
"access_token": true,
"refresh_token": true,
"id_token": true,
"email": true,
"authenticated": true,
"access_token": true,
"refresh_token": true,
"id_token": true,
"user_identifier": true,
"authenticated": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
@@ -164,7 +164,7 @@ func decompressCombinedPayload(compressed string) (*combinedSessionPayload, erro
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gr.Close()
defer func() { _ = gr.Close() }()
// Limit decompressed size to prevent zip bombs
limitedReader := io.LimitReader(gr, 512*1024) // 512KB max
@@ -500,6 +500,11 @@ func (sm *SessionManager) combinedChunkCookieName(chunkIndex int) string {
return fmt.Sprintf("%s_%d", sm.combinedCookieName(), chunkIndex)
}
// GetCookiePrefix returns the cookie prefix used for all OIDC session cookies.
func (sm *SessionManager) GetCookiePrefix() string {
return sm.cookiePrefix
}
// Shutdown gracefully shuts down the SessionManager and all its background tasks
func (sm *SessionManager) Shutdown() error {
var shutdownErr error
@@ -1129,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
// Populate legacy session values from combined payload
sessionData.mainSession.Values["email"] = payload.E
sessionData.mainSession.Values["user_identifier"] = payload.Ui
sessionData.mainSession.Values["authenticated"] = payload.Au
sessionData.mainSession.Values["csrf"] = payload.Cs
sessionData.mainSession.Values["nonce"] = payload.N
@@ -1211,6 +1216,18 @@ type SessionData struct {
dirty bool
inUse bool
// cachedClaimsToken is the ID token string whose claims were last parsed and
// cached. A lazy, per-request cache to avoid re-parsing the JWT on every
// authenticated request (e.g. for headerTemplates). Protected by sessionMutex.
cachedClaimsToken string
// cachedClaims holds the parsed claims for cachedClaimsToken.
cachedClaims map[string]interface{}
// cachedClaimsErr holds the parse error (if any) for cachedClaimsToken so
// failures are not retried within the same request.
cachedClaimsErr error
}
// IsDirty returns true if the session data has been modified since it was last loaded or saved.
@@ -1261,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
A: sd.getAccessTokenUnsafe(),
R: sd.getRefreshTokenUnsafe(),
I: sd.getIDTokenUnsafe(),
E: sd.getEmailUnsafe(),
Ui: sd.getUserIdentifierUnsafe(),
Au: sd.getAuthenticatedUnsafe(),
Cs: sd.getCSRFUnsafe(),
N: sd.getNonceUnsafe(),
@@ -1548,9 +1565,10 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
}()
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
sd.clearAllSessionData(r, true)
// Release the lock before calling Save to prevent deadlock
sd.sessionMutex.Unlock()
// This is primarily for testing - in production w will often be nil
var err error
@@ -1588,7 +1606,7 @@ func (sd *SessionData) returnToPoolSafely() {
// Parameters:
// - r: The HTTP request context.
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
func (sd *SessionData) clearTokenChunks(_ *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
clearSessionValues(session, true)
}
@@ -1731,6 +1749,12 @@ func (sd *SessionData) Reset() {
sd.request = nil
sd.useCombinedStorage = true // Reset to use combined storage by default
// Drop any cached claims so pooled SessionData does not leak claim data
// between requests/users.
sd.cachedClaimsToken = ""
sd.cachedClaims = nil
sd.cachedClaimsErr = nil
// Reset the refresh mutex to ensure clean state
// Note: We don't need to lock it since sessionMutex is already held
// and this session is not in use by any request
@@ -1820,23 +1844,12 @@ func (sd *SessionData) SetAccessToken(token string) {
defer sd.sessionMutex.Unlock()
if token != "" {
dotCount := strings.Count(token, ".")
// Reject tokens with exactly 1 dot (invalid format - neither JWT nor opaque)
if dotCount == 1 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Invalid token format during storage (dots: %d) - rejecting", dotCount)
}
return
}
// For opaque tokens (no dots), ensure minimum length for security
if dotCount == 0 && len(token) < 20 {
if len(token) < 20 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
}
return
}
// Tokens with 2 dots are JWTs, tokens with 0 dots are opaque
// Both are valid formats
}
currentAccessToken := sd.getAccessTokenUnsafe()
@@ -2456,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
}
}
// GetEmail retrieves the authenticated user's email address.
// The email is extracted from ID token claims and used for
// authorization decisions and header injection.
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
// upn, preferred_username, etc.). The value is used for authorization
// decisions and header injection.
// Returns:
// - The user's email address string, or an empty string if not set.
func (sd *SessionData) GetEmail() string {
// - The user identifier string, or an empty string if not set.
func (sd *SessionData) GetUserIdentifier() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
email, _ := sd.mainSession.Values["email"].(string)
return email
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
return userIdentifier
}
// SetEmail stores the authenticated user's email address.
// The email is typically extracted from the 'email' claim in the ID token.
// SetUserIdentifier stores the authenticated user's identifier value.
// Parameters:
// - email: The user's email address to store.
func (sd *SessionData) SetEmail(email string) {
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
currentVal, _ := sd.mainSession.Values["email"].(string)
if currentVal != email {
sd.mainSession.Values["email"] = email
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
if currentVal != userIdentifier {
sd.mainSession.Values["user_identifier"] = userIdentifier
sd.dirty = true
}
}
@@ -2519,6 +2532,41 @@ func (sd *SessionData) GetIDToken() string {
return sd.getIDTokenUnsafe()
}
// GetIDTokenClaims returns claims parsed from the current ID token, caching
// the result on the SessionData so repeated callers within the same request
// do not re-parse the JWT. The cache is keyed on the ID token string and is
// cleared when the SessionData is reset (see Reset) or when the ID token
// changes (e.g. after a refresh).
//
// The parser parameter is typically the TraefikOidc.extractClaimsFunc, which
// lets tests inject mocks just like the direct call it replaces.
//
// Returns an empty claims map and a nil error when the session has no ID
// token, matching the existing "no-op" behavior of the caller sites.
func (sd *SessionData) GetIDTokenClaims(parser func(string) (map[string]interface{}, error)) (map[string]interface{}, error) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
token := sd.getIDTokenUnsafe()
if token == "" {
// Invalidate any stale cache without running the parser.
sd.cachedClaimsToken = ""
sd.cachedClaims = nil
sd.cachedClaimsErr = nil
return nil, nil
}
if sd.cachedClaimsToken == token && (sd.cachedClaims != nil || sd.cachedClaimsErr != nil) {
return sd.cachedClaims, sd.cachedClaimsErr
}
claims, err := parser(token)
sd.cachedClaimsToken = token
sd.cachedClaims = claims
sd.cachedClaimsErr = err
return claims, err
}
// getIDTokenUnsafe retrieves the ID token without acquiring locks.
// Enhanced ID token retrieval with comprehensive integrity checks and chunking support.
// Used when the session mutex is already held to prevent deadlocks.
@@ -2578,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
return result.Token
}
// getEmailUnsafe retrieves the email without acquiring locks.
func (sd *SessionData) getEmailUnsafe() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
func (sd *SessionData) getUserIdentifierUnsafe() string {
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
return userIdentifier
}
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
+6 -7
View File
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
s.False(session.IsDirty())
}
// TestSessionData_SetEmail tests email setter with dirty tracking
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
session, err := s.sessionManager.GetSession(req)
s.Require().NoError(err)
defer session.returnToPoolSafely()
// Set email
session.SetEmail("test@example.com")
s.Equal("test@example.com", session.GetEmail())
session.SetUserIdentifier("test@example.com")
s.Equal("test@example.com", session.GetUserIdentifier())
s.True(session.IsDirty())
}
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
// Set some data
err = session.SetAuthenticated(true)
s.Require().NoError(err)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetCSRF("csrf-token")
// Clear session
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
defer session.returnToPoolSafely()
// Modify session
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
s.True(session.IsDirty())
// Save session
+2
View File
@@ -926,6 +926,8 @@ func (cm *ChunkManager) detectRepeatedCharacters(token string, config TokenConfi
//
// Returns:
// - An error if the token is expired or has invalid expiration, nil if valid.
//
//nolint:unparam // error return kept for API consistency and future use
func (cm *ChunkManager) validateTokenExpiration(token string, config TokenConfig) error {
if !strings.Contains(token, ".") {
return nil
+6 -6
View File
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Set up initial session state (what user has when first logging in)
session1.SetAuthenticated(true)
session1.SetEmail(originalUserData["email"].(string))
session1.SetUserIdentifier(originalUserData["email"].(string))
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Simulate what happens when middleware detects expired tokens
// It should preserve session state while attempting token refresh
originalAuth := session2.GetAuthenticated()
originalEmail := session2.GetEmail()
originalEmail := session2.GetUserIdentifier()
// Reconstruct user data from individual stored keys
originalUserDataStored := make(map[string]interface{})
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Verify all session data is still intact after token refresh
postRefreshAuth := session2.GetAuthenticated()
postRefreshEmail := session2.GetEmail()
postRefreshEmail := session2.GetUserIdentifier()
userDataPresent := true
for k := range originalUserData {
if session2.mainSession.Values["user_data_"+k] == nil {
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
// Set up session with specific creation time
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
// Create tokens with specific expiry
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
// Set up session with data that should be preserved or removed
session.SetAuthenticated(true)
session.SetEmail("cleanup@example.com")
session.SetUserIdentifier("cleanup@example.com")
session.mainSession.Values["user_data"] = "Test User|user-123"
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
if scenario.shouldCleanup {
if sessionTooOld {
session.SetAuthenticated(false)
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
for key := range session.mainSession.Values {
+81 -116
View File
@@ -1,6 +1,7 @@
package traefikoidc
import (
"crypto/x509"
"fmt"
"io"
"log"
@@ -54,6 +55,15 @@ type Config struct {
AllowedUsers []string `json:"allowedUsers"`
Headers []TemplatedHeader `json:"headers"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
// a stored refresh token. Once the token has been in the session longer
// than this, requests treat it as expired up-front - returning 401 to
// AJAX callers and triggering full re-auth on navigations - instead of
// hammering the IdP with grants that will only fail with invalid_grant.
// IdPs do not expose RT TTL on the wire, so this is intentionally a
// conservative heuristic; tune to match your provider configuration.
// Default 21600 (6h). Set to 0 to disable the check.
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
SessionMaxAge int `json:"sessionMaxAge"`
RateLimit int `json:"rateLimit"`
OverrideScopes bool `json:"overrideScopes"`
@@ -65,6 +75,51 @@ type Config struct {
ForceHTTPS bool `json:"forceHTTPS"`
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
StripAuthCookies bool `json:"stripAuthCookies,omitempty"`
EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
// CACertPath is an optional filesystem path to a PEM-encoded CA bundle used
// to verify the OIDC provider's TLS certificate. Use this when the provider
// is signed by an internal/private CA that is not in the system trust store.
CACertPath string `json:"caCertPath,omitempty"`
// CACertPEM is an optional inline PEM-encoded CA bundle, equivalent to
// CACertPath but supplied directly in the middleware configuration. Both
// may be set; certificates from both sources are combined.
CACertPEM string `json:"caCertPEM,omitempty"`
// InsecureSkipVerify disables TLS certificate verification for the OIDC
// provider. Intended ONLY for local development against self-signed
// providers. Enabling this in production is a security hole — prefer
// CACertPath/CACertPEM. Emits a loud warning at startup.
InsecureSkipVerify bool `json:"insecureSkipVerify,omitempty"`
}
// loadCACertPool assembles an x509.CertPool from CACertPath and CACertPEM.
// Returns (nil, nil) when neither is configured — callers should fall back to
// the system trust store. Returns a descriptive error if a PEM source is
// configured but contains no parseable certificates, so misconfigurations
// surface at startup rather than as unexplained TLS failures at runtime.
func (c *Config) loadCACertPool() (*x509.CertPool, error) {
if c.CACertPath == "" && c.CACertPEM == "" {
return nil, nil
}
pool := x509.NewCertPool()
if c.CACertPath != "" {
data, err := os.ReadFile(c.CACertPath)
if err != nil {
return nil, fmt.Errorf("read caCertPath %q: %w", c.CACertPath, err)
}
if !pool.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("caCertPath %q: no valid PEM certificates found", c.CACertPath)
}
}
if c.CACertPEM != "" {
if !pool.AppendCertsFromPEM([]byte(c.CACertPEM)) {
return nil, fmt.Errorf("caCertPEM: no valid PEM certificates found")
}
}
return pool, nil
}
// RedisConfig configures Redis cache backend settings for distributed caching.
@@ -98,8 +153,15 @@ type DynamicClientRegistrationConfig struct {
InitialAccessToken string `json:"initialAccessToken,omitempty"`
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
CredentialsFile string `json:"credentialsFile,omitempty"`
Enabled bool `json:"enabled"`
PersistCredentials bool `json:"persistCredentials"`
// StorageBackend specifies where to store DCR credentials: "file", "redis", or "auto"
// - "file": Use file-based storage (default for backward compatibility)
// - "redis": Use Redis exclusively (fails if Redis unavailable)
// - "auto": Use Redis if available, fallback to file (default)
StorageBackend string `json:"storageBackend,omitempty"`
// RedisKeyPrefix is the prefix for Redis keys when using Redis storage (default: "dcr:creds:")
RedisKeyPrefix string `json:"redisKeyPrefix,omitempty"`
Enabled bool `json:"enabled"`
PersistCredentials bool `json:"persistCredentials"`
}
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
@@ -194,6 +256,7 @@ func CreateConfig() *Config {
EnablePKCE: false, // PKCE is opt-in
OverrideScopes: false, // Default to appending scopes, not overriding
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc
SecurityHeaders: createDefaultSecurityConfig(),
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
}
@@ -317,6 +380,11 @@ func (c *Config) Validate() error {
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
}
// Validate refresh-token max-age heuristic
if c.MaxRefreshTokenAgeSeconds < 0 {
return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative")
}
// Validate audience if specified
if c.Audience != "" {
// Validate audience format - should be a valid identifier or URL
@@ -722,7 +790,18 @@ func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// IsDebug reports whether debug-level logging is enabled.
// Callers should use this to avoid expensive format-string expansion
// (e.g. on hot paths under yaegi) when debug output would be discarded.
func (l *Logger) IsDebug() bool {
if l == nil || l.logDebug == nil {
return false
}
return l.logDebug.Writer() != io.Discard
}
// newNoOpLogger creates a logger that discards all output.
//
// Deprecated: Use GetSingletonNoOpLogger() instead for better memory efficiency.
func newNoOpLogger() *Logger {
return GetSingletonNoOpLogger()
@@ -737,15 +816,6 @@ func newNoOpLogger() *Logger {
// - code: The HTTP status code for the response.
// - logger: The Logger instance to use for logging the error.
//
// handleError writes an HTTP error response with the specified status code and message.
// It logs the error and sets appropriate headers before writing the response.
//
//lint:ignore U1000 Kept for potential future error handling
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error("%s", message)
http.Error(w, message, code)
}
// GetSecurityHeadersApplier returns a function that applies security headers
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
@@ -1051,111 +1121,6 @@ func (rc *RedisConfig) ApplyEnvFallbacks() {
}
}
// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
// This function is kept for backward compatibility but should not be used directly.
func LoadRedisConfigFromEnv() *RedisConfig {
// Check if Redis is enabled
enabledStr := os.Getenv("REDIS_ENABLED")
if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
return nil
}
config := &RedisConfig{
Enabled: true,
}
// Parse numeric values
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
if db, err := strconv.Atoi(dbStr); err == nil {
config.DB = db
}
}
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
config.PoolSize = poolSize
}
}
if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
config.ConnectTimeout = timeout
}
}
if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
config.ReadTimeout = timeout
}
}
if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
config.WriteTimeout = timeout
}
}
// Parse boolean values
if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
config.EnableTLS = true
}
if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
config.TLSSkipVerify = true
}
// Parse hybrid mode settings
if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
if size, err := strconv.Atoi(l1SizeStr); err == nil {
config.HybridL1Size = size
}
}
if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
config.HybridL1MemoryMB = memory
}
}
// Parse circuit breaker settings
if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
config.EnableCircuitBreaker = false
} else {
config.EnableCircuitBreaker = true // Default to enabled
}
if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
config.CircuitBreakerThreshold = threshold
}
}
if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
config.CircuitBreakerTimeout = timeout
}
}
// Parse health check settings
if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
config.EnableHealthCheck = false
} else {
config.EnableHealthCheck = true // Default to enabled
}
if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
config.HealthCheckInterval = interval
}
}
// Apply defaults after loading from env
config.ApplyDefaults()
return config
}
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if origin == allowed || allowed == "*" {
+13 -6
View File
@@ -548,17 +548,24 @@ func (gc *GenericCache) Delete(key string) {
delete(gc.data, key)
}
// cleanupRoutine periodically cleans up the cache
// cleanupRoutine periodically wipes the cache.
//
// NOTE: GenericCache does not track per-entry timestamps, so this is a
// "clear-all on tick" strategy — every `gc.ttl` interval the entire map
// is replaced, regardless of when each entry was written. This is the
// intentional (simplified) behavior of GenericCache, which exists mainly
// as a generic fallback for tests and non-typed caches. Callers that
// require true per-entry TTL must use UniversalCache / UnifiedCache which
// track expiry per entry.
func (gc *GenericCache) cleanupRoutine() {
ticker := time.NewTicker(gc.ttl)
defer ticker.Stop()
wipeTicker := time.NewTicker(gc.ttl)
defer wipeTicker.Stop()
for {
select {
case <-ticker.C:
case <-wipeTicker.C:
gc.mu.Lock()
// Simple cleanup - clear all data after TTL
// In production, you'd track individual entry TTLs
// Clear-all on tick, not per-entry TTL (see function doc).
gc.data = make(map[string]interface{})
gc.mu.Unlock()
case <-gc.stopChan:
+51 -5
View File
@@ -4,7 +4,10 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"sync"
"sync/atomic"
@@ -251,6 +254,30 @@ func TestSingletonResourceManager(t *testing.T) {
})
}
// createMockOIDCServer creates a mock OIDC server for testing
func createMockOIDCServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/authorize",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/jwks",
"userinfo_endpoint": "https://example.com/userinfo",
})
case "/jwks":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []interface{}{},
})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
}
// TestContextAwareGoroutineManagement tests context-aware goroutine management
func TestContextAwareGoroutineManagement(t *testing.T) {
t.Run("GoroutineCleanupOnContextCancel", func(t *testing.T) {
@@ -259,13 +286,17 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
// Create mock OIDC server
mockServer := createMockOIDCServer()
defer mockServer.Close()
initialGoroutines := runtime.NumGoroutine()
ctx, cancel := context.WithCancel(context.Background())
// Create a TraefikOidc instance with context
config := &Config{
ProviderURL: "https://example.com",
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
}
@@ -308,12 +339,20 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
// Create mock OIDC servers
mockServer1 := createMockOIDCServer()
defer mockServer1.Close()
mockServer2 := createMockOIDCServer()
defer mockServer2.Close()
mockServer3 := createMockOIDCServer()
defer mockServer3.Close()
initialGoroutines := runtime.NumGoroutine()
configs := []Config{
{ProviderURL: "https://example1.com", ClientID: "client1", ClientSecret: "secret1"},
{ProviderURL: "https://example2.com", ClientID: "client2", ClientSecret: "secret2"},
{ProviderURL: "https://example3.com", ClientID: "client3", ClientSecret: "secret3"},
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
}
var plugins []*TraefikOidc
@@ -366,6 +405,13 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
// Create mock OIDC servers
mockServers := make([]*httptest.Server, 3)
for i := 0; i < 3; i++ {
mockServers[i] = createMockOIDCServer()
defer mockServers[i].Close()
}
rm := GetResourceManager()
// Register singleton cleanup task
@@ -386,7 +432,7 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
for i := 0; i < 3; i++ {
ctx := context.Background()
config := &Config{
ProviderURL: fmt.Sprintf("https://example%d.com", i),
ProviderURL: mockServers[i].URL,
ClientID: fmt.Sprintf("client%d", i),
ClientSecret: fmt.Sprintf("secret%d", i),
}
+1 -1
View File
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
}
session.SetAuthenticated(true)
session.SetEmail(tf.fixtures.UserEmail)
session.SetUserIdentifier(tf.fixtures.UserEmail)
session.SetAccessToken(tf.fixtures.AccessToken)
session.SetRefreshToken(tf.fixtures.RefreshToken)
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
+5 -5
View File
@@ -22,7 +22,7 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
// Test helper adapters for the new test files
// resetGlobalState resets all global singletons to prevent test interference
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func resetGlobalState() {
// Reset global task registry first to stop all background tasks
@@ -137,7 +137,7 @@ func (tc *testCleanup) cleanupAll() {
}
// createTestConfig creates a config with all required fields populated for testing
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func createTestConfig() *Config {
config := CreateConfig()
@@ -151,7 +151,7 @@ func createTestConfig() *Config {
*/
// setupTestOIDCMiddleware creates a test OIDC middleware instance with mock servers
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httptest.Server) {
// Reset global state to ensure test isolation
@@ -339,7 +339,7 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
*/
// createMockJWT creates a mock JWT token for testing - adapter for existing tests
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func createMockJWT(t *testing.T, sub, email string) string {
return ValidIDToken
@@ -361,7 +361,7 @@ func createTestSession() *SessionData {
}
// injectSessionIntoRequest saves the session and adds the resulting cookies to the request
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func injectSessionIntoRequest(t *testing.T, req *http.Request, session *SessionData) {
// Create a response recorder to capture cookies
+132 -65
View File
@@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/url"
"runtime"
"strings"
"time"
)
@@ -46,6 +47,17 @@ func (t *TraefikOidc) VerifyToken(token string) error {
}
}
// Hot-path fast-return: a previously-verified token has already passed
// signature, claims, and replay checks. Skipping the parseJWT cost here
// matters under bursty traffic (e.g. 10+ concurrent panel requests on
// every Grafana dashboard refresh) where the same token is validated
// dozens of times per second by validateStandardTokens.
if t.tokenCache != nil {
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
return nil
}
}
parsedJWT, parseErr := parseJWT(token)
if parseErr != nil {
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
@@ -63,12 +75,6 @@ func (t *TraefikOidc) VerifyToken(token string) error {
}
}
// Check token cache FIRST - if token is already verified and cached, return immediately
// This prevents false positives when multiple goroutines validate the same token concurrently
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
return nil
}
// Only check JTI blacklist for tokens that aren't already in the cache
// This is for FIRST-TIME validation to detect replay attacks
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
@@ -315,15 +321,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
jwksURL := t.jwksURL
t.metadataMu.RUnlock()
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
if !t.suppressDiagnosticLogs && jwks != nil {
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), jwksURL)
}
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
@@ -337,38 +334,12 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
t.safeLogDebugf("DIAGNOSTIC: Looking for kid=%s, alg=%s in JWKS", kid, alg)
}
if jwks == nil {
return fmt.Errorf("JWKS is nil, cannot verify token")
}
// Find the matching key in JWKS
var matchingKey *JWK
availableKids := make([]string, 0, len(jwks.Keys))
for _, key := range jwks.Keys {
availableKids = append(availableKids, key.Kid)
if key.Kid == kid {
matchingKey = &key
break
}
}
if matchingKey == nil {
if !t.suppressDiagnosticLogs {
t.safeLogErrorf("DIAGNOSTIC: No matching key found for kid=%s. Available kids: %v", kid, availableKids)
}
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
if !t.suppressDiagnosticLogs {
t.safeLogDebugf("DIAGNOSTIC: Found matching key for kid=%s, key type: %s", kid, matchingKey.Kty)
}
publicKeyPEM, err := jwkToPEM(matchingKey)
pubKey, err := t.jwkCache.GetPublicKey(context.Background(), jwksURL, kid, t.httpClient)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
return fmt.Errorf("failed to get public key: %w", err)
}
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
if err := verifySignatureWithKey(token, pubKey, alg); err != nil {
if !t.suppressDiagnosticLogs {
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
}
@@ -451,10 +422,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
}
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
if err != nil {
errMsg := err.Error()
//nolint:gocritic // Complex error handling with provider-specific conditions
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
t.logger.Debug("Refresh token expired or revoked: %v", err)
// Clear all tokens and authentication state when refresh token is invalid
@@ -464,7 +434,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
session.SetRefreshToken("")
session.SetAccessToken("")
session.SetIDToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens as well to prevent any replay attacks
session.SetCSRF("")
session.SetNonce("")
@@ -506,12 +476,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
return false
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
return false
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
return false
}
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
}
session.SetEmail(email)
session.SetUserIdentifier(userIdentifier)
// Get token expiry information for logging
var expiryTime time.Time
@@ -537,7 +513,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
session.SetAccessToken("")
session.SetIDToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
return false
}
@@ -554,6 +530,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// coordinatedTokenRefresh routes a refresh-token grant through the
// RefreshCoordinator so that concurrent requests sharing the same refresh
// token coalesce into a single upstream call. This prevents the thundering
// herd that yields invalid_grant when the IdP rotates refresh tokens.
//
// Falls back to a direct call when the coordinator is nil, which only
// happens in tests that build TraefikOidc literals without going through
// NewWithContext.
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
if t.refreshCoordinator == nil {
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
}
parentCtx := context.Background()
if req != nil {
parentCtx = req.Context()
}
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
defer cancel()
sessionID := refreshCoordinatorSessionID(refreshToken)
return t.refreshCoordinator.CoordinateRefresh(
ctx,
sessionID,
refreshToken,
func() (*TokenResponse, error) {
// Cross-replica dedup. The in-process coordinator already
// collapses concurrent grants on this pod; this Redis-backed
// short-TTL cache covers the (rare) case of a failover or
// load-balancer reroute mid-refresh, where two pods would
// otherwise both POST the same refresh_token to the IdP.
if cached, ok := t.lookupCachedRefreshResult(sessionID); ok {
return cached, nil
}
resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
if err == nil && resp != nil {
t.cacheRefreshResult(sessionID, resp)
}
return resp, err
},
)
}
// lookupCachedRefreshResult returns a previously-stored TokenResponse for the
// given refresh-token hash, if one exists and is still within its short TTL.
// The cache wraps the universal cache, which is Redis-backed in production -
// so a "hit" here means another Traefik replica refreshed this same token
// within the last few seconds.
func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) {
if t.refreshResultCache == nil {
return nil, false
}
v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID))
if !ok || v == nil {
return nil, false
}
if tr, ok := v.(*TokenResponse); ok && tr != nil {
return tr, true
}
return nil, false
}
// cacheRefreshResult stores the new TokenResponse under the refresh-token
// hash for a short window. TTL is intentionally tight: the rotated refresh
// token cannot be re-presented to the IdP, and any peer waiting longer than
// this window has almost certainly given up via its own coordinator timeout.
func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) {
if t.refreshResultCache == nil || resp == nil {
return
}
t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL)
}
// refreshResultCacheKey namespaces refresh-result entries inside the shared
// cache namespace.
func refreshResultCacheKey(sessionID string) string {
return "rt-result:" + sessionID
}
// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache.
// Long enough for a sibling replica to observe the result, short enough that
// a stale entry never re-supplies a token after the IdP has already moved on.
const refreshResultCacheTTL = 5 * time.Second
// RevokeToken revokes a token locally by adding it to the blacklist cache.
// It removes the token from the verification cache and adds both the token
// and its JTI (if present) to the blacklist to prevent future use.
@@ -1139,9 +1200,14 @@ func (t *TraefikOidc) startTokenCleanup() {
sessionManager := t.sessionManager
logger := t.logger
// Only use the fast cleanup interval when actually running under `go test`.
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
// jwkCache.Cleanup both being no-ops there.
cleanupInterval := 1 * time.Minute
if isTestMode() {
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
if isTestMode() && runtime.Compiler != "yaegi" {
cleanupInterval = 50 * time.Millisecond
}
// Create cleanup function
@@ -1183,25 +1249,27 @@ func (t *TraefikOidc) startTokenCleanup() {
}
// extractGroupsAndRoles extracts group and role information from token claims.
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
// Parameters:
// - idToken: The ID token containing claims to extract.
// It parses the configured group/role claims from the supplied ID token.
//
// Returns:
// - groups: Array of group names from the 'groups' claim.
// - roles: Array of role names from the 'roles' claim.
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
// but not arrays of strings.
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
// to avoid re-parsing the JWT.
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
return t.extractGroupsAndRolesFromClaims(claims)
}
// extractGroupsAndRolesFromClaims extracts group and role information from
// already-parsed claims. Hot path: callers that have a cached claims map (such
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
// base64+JSON decode of the JWT on every authenticated request.
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
var groups []string
var roles []string
// Extract groups using configurable claim name (defaults to "groups")
if groupsClaim, exists := claims[t.groupClaimName]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
@@ -1217,7 +1285,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
}
}
// Extract roles using configurable claim name (defaults to "roles")
if rolesClaim, exists := claims[t.roleClaimName]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
+9
View File
@@ -95,6 +95,7 @@ type TraefikOidc struct {
cancelFunc context.CancelFunc
errorRecoveryManager *ErrorRecoveryManager
tokenResilienceManager *TokenResilienceManager
refreshCoordinator *RefreshCoordinator
goroutineWG *sync.WaitGroup
dcrConfig *DynamicClientRegistrationConfig
dynamicClientRegistrar *DynamicClientRegistrar
@@ -119,14 +120,22 @@ type TraefikOidc struct {
clientID string
clientSecret string
registrationURL string
backchannelLogoutPath string
frontchannelLogoutPath string
scopesSupported []string
scopes []string
refreshGracePeriod time.Duration
maxRefreshTokenAge time.Duration
metadataMu sync.RWMutex
shutdownOnce sync.Once
metadataRetryMutex sync.Mutex
firstRequestMutex sync.Mutex
sessionInvalidationCache CacheInterface
refreshResultCache CacheInterface
minimalHeaders bool
stripAuthCookies bool
enableBackchannelLogout bool
enableFrontchannelLogout bool
firstRequestReceived bool
requireTokenIntrospection bool
metadataRefreshStarted bool
+145 -28
View File
@@ -21,6 +21,10 @@ const (
CacheTypeJWK CacheType = "jwk"
CacheTypeSession CacheType = "session"
CacheTypeGeneral CacheType = "general"
// maxCacheEntrySize defines the maximum size for a single cache entry (64 MiB)
// This prevents integer overflow when allocating memory for serialization
maxCacheEntrySize = 64 * 1024 * 1024
)
// UniversalCacheConfig provides configuration for the universal cache
@@ -248,6 +252,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
}
}
return c.setLocal(key, value, ttl)
}
// SetLocal stores a value only in the in-memory LRU, bypassing any
// distributed backend. Use for values that don't survive JSON round-tripping
// — interfaces holding concrete crypto keys, *big.Int, or types whose
// unexported fields yaegi exposes under an X prefix on Marshal. Each replica
// caches independently; correctness must not depend on cross-replica
// coherence for these keys.
func (c *UniversalCache) SetLocal(key string, value interface{}, ttl time.Duration) error {
if ttl == 0 {
ttl = c.config.DefaultTTL
}
return c.setLocal(key, value, ttl)
}
// setLocal performs the in-memory portion of a write. ttl must already be
// resolved against DefaultTTL by the caller.
func (c *UniversalCache) setLocal(key string, value interface{}, ttl time.Duration) error {
size := c.estimateSize(value)
c.mu.Lock()
@@ -302,8 +325,10 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
c.currentMemory += size
}
c.logger.Debugf("UniversalCache[%s]: Set key=%s, ttl=%v, size=%d bytes",
c.config.Type, key, ttl, size)
if c.logger.IsDebug() {
c.logger.Debugf("UniversalCache[%s]: Set key=%s, ttl=%v, size=%d bytes",
c.config.Type, key, ttl, size)
}
return nil
}
@@ -327,15 +352,54 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
// Fall through to local cache
} else {
atomic.AddInt64(&c.hits, 1)
// Update local cache with backend value
go func() {
_ = c.updateLocalCache(key, value, c.config.DefaultTTL)
}()
// Update local cache with backend value synchronously.
// Under yaegi, goroutine spawn is 5-10x costlier than compiled Go,
// and this path fires per-request on cold local cache.
// updateLocalCache is cheap (map write under mutex).
_ = c.updateLocalCache(key, value, c.config.DefaultTTL)
return value, true
}
}
}
return c.getLocal(key)
}
// GetLocal retrieves a value only from the in-memory LRU, never querying the
// distributed backend. Pair with SetLocal for values that aren't safe to
// serialize (see SetLocal docstring).
func (c *UniversalCache) GetLocal(key string) (interface{}, bool) {
return c.getLocal(key)
}
// getLocal returns the in-memory entry for key honoring expiry, grace
// periods, and the RLock fast path used by token/JWK/session caches.
func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
// Fast read path for caches whose eviction is dominated by TTL rather than
// access-recency (token, JWK, session). Holding only an RLock here lets all
// concurrent readers verify cached tokens in parallel — under yaegi the
// previous unconditional Lock serialized every JWT verify on a single
// mutex and pinned a CPU under load.
switch c.config.Type {
case CacheTypeToken, CacheTypeJWK, CacheTypeSession:
c.mu.RLock()
item, exists := c.items[key]
if !exists {
c.mu.RUnlock()
atomic.AddInt64(&c.misses, 1)
return nil, false
}
if !time.Now().After(item.ExpiresAt) {
value := item.Value
c.mu.RUnlock()
atomic.AddInt64(&c.hits, 1)
return value, true
}
c.mu.RUnlock()
// Expired — fall through to the write-locked slow path below to
// remove the entry under exclusive access.
}
c.mu.Lock()
defer c.mu.Unlock()
@@ -436,7 +500,7 @@ func (c *UniversalCache) Clear() {
c.currentSize = 0
c.currentMemory = 0
c.logger.Infof("UniversalCache[%s]: Cleared all items", c.config.Type)
c.logger.Debugf("UniversalCache[%s]: Cleared all items", c.config.Type)
}
// Size returns the number of items in the cache
@@ -536,7 +600,9 @@ func (c *UniversalCache) evictOldest() {
if item, exists := c.items[key]; exists {
c.removeItem(key, item)
atomic.AddInt64(&c.evictions, 1)
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
if c.logger.IsDebug() {
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
}
}
}
}
@@ -720,22 +786,6 @@ func (c *UniversalCache) SetWithMetadata(key string, value interface{}, ttl time
return nil
}
// GetTyped retrieves a typed value from the cache
func GetTyped[T any](c *UniversalCache, key string) (T, bool) {
var zero T
value, exists := c.Get(key)
if !exists {
return zero, false
}
typed, ok := value.(T)
if !ok {
return zero, false
}
return typed, true
}
// TokenCacheOperations provides token-specific operations
func (c *UniversalCache) BlacklistToken(token string, ttl time.Duration) error {
if c.config.Type != CacheTypeToken {
@@ -784,14 +834,81 @@ func (c *UniversalCache) Strategy() CacheStrategy {
// serialize converts a value to bytes for backend storage
func (c *UniversalCache) serialize(value interface{}) ([]byte, error) {
// Use JSON for serialization - simple and universal
return json.Marshal(value)
// If value is already a byte slice (e.g., pre-marshaled JSON from metadata_cache),
// store it directly with a marker to prevent double-encoding.
// This fixes the issue where []byte was being JSON-marshaled, causing Base64 encoding.
if bytes, ok := value.([]byte); ok {
// Validate size to prevent integer overflow
if len(bytes) > maxCacheEntrySize {
return nil, fmt.Errorf("cache entry size %d exceeds maximum allowed size %d", len(bytes), maxCacheEntrySize)
}
// Check for potential overflow when adding marker byte
if len(bytes) == maxCacheEntrySize {
return nil, fmt.Errorf("cache entry size would overflow when adding marker byte")
}
// Prepend marker byte 0x00 to indicate raw bytes (not JSON-encoded)
result := make([]byte, len(bytes)+1)
result[0] = 0x00
copy(result[1:], bytes)
return result, nil
}
// For all other types (maps, strings, etc.), use JSON encoding
// Prepend marker byte 0x01 to indicate JSON-encoded data
jsonData, err := json.Marshal(value)
if err != nil {
return nil, err
}
// Validate size to prevent integer overflow
if len(jsonData) > maxCacheEntrySize {
return nil, fmt.Errorf("serialized cache entry size %d exceeds maximum allowed size %d", len(jsonData), maxCacheEntrySize)
}
// Check for potential overflow when adding marker byte
if len(jsonData) == maxCacheEntrySize {
return nil, fmt.Errorf("serialized cache entry size would overflow when adding marker byte")
}
result := make([]byte, len(jsonData)+1)
result[0] = 0x01
copy(result[1:], jsonData)
return result, nil
}
// deserialize converts bytes from backend storage to a value
func (c *UniversalCache) deserialize(data []byte, value interface{}) error {
// Use JSON for deserialization
return json.Unmarshal(data, value)
if len(data) == 0 {
return fmt.Errorf("cannot deserialize empty data")
}
// Check for type marker (added by serialize)
if data[0] == 0x00 {
// Raw bytes - strip marker and return as-is
rawBytes := data[1:]
if ptr, ok := value.(*interface{}); ok {
*ptr = rawBytes
return nil
}
return fmt.Errorf("cannot deserialize raw bytes into %T", value)
}
if data[0] == 0x01 {
// JSON-encoded - strip marker and unmarshal
return json.Unmarshal(data[1:], value)
}
// Legacy data without marker (for backward compatibility)
// Try to unmarshal as JSON
if err := json.Unmarshal(data, value); err != nil {
// If unmarshal fails, treat as raw bytes
if ptr, ok := value.(*interface{}); ok {
*ptr = data
return nil
}
return err
}
return nil
}
// prefixKey adds a cache type prefix to the key for backend storage
+517
View File
@@ -0,0 +1,517 @@
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestUniversalCache_SerializeDeserialize tests the fix for issue #116
// where metadata was stored as Base64-encoded JSON but read as plain JSON
func TestUniversalCache_SerializeDeserialize(t *testing.T) {
t.Parallel()
t.Run("RawBytesPreserved", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Test data: pre-marshaled JSON bytes (like metadata_cache uses)
testData := []byte(`{"issuer":"https://example.com","jwks_uri":"https://example.com/jwks"}`)
// Serialize
serialized, err := cache.serialize(testData)
require.NoError(t, err)
assert.NotNil(t, serialized)
// Should have marker byte
assert.Equal(t, byte(0x00), serialized[0], "Should have raw bytes marker")
assert.Equal(t, testData, serialized[1:], "Data should be preserved after marker")
// Deserialize
var result interface{}
err = cache.deserialize(serialized, &result)
require.NoError(t, err)
// Should get back []byte
resultBytes, ok := result.([]byte)
require.True(t, ok, "Result should be []byte")
assert.Equal(t, testData, resultBytes, "Deserialized data should match original")
})
t.Run("JSONEncodedTypes", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
testCases := []struct {
name string
value interface{}
}{
{
name: "Map",
value: map[string]interface{}{"key": "value", "number": 42.0},
},
{
name: "String",
value: "test-string",
},
{
name: "Number",
value: 123.456,
},
{
name: "Array",
value: []interface{}{"a", "b", "c"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Serialize
serialized, err := cache.serialize(tc.value)
require.NoError(t, err)
assert.NotNil(t, serialized)
// Should have JSON marker byte
assert.Equal(t, byte(0x01), serialized[0], "Should have JSON marker")
// Verify the JSON portion is valid
var checkJSON interface{}
err = json.Unmarshal(serialized[1:], &checkJSON)
require.NoError(t, err, "Should be valid JSON after marker")
// Deserialize
var result interface{}
err = cache.deserialize(serialized, &result)
require.NoError(t, err)
// Compare results (using JSON round-trip for consistent comparison)
expectedJSON, _ := json.Marshal(tc.value)
resultJSON, _ := json.Marshal(result)
assert.JSONEq(t, string(expectedJSON), string(resultJSON), "Deserialized data should match original")
})
}
})
t.Run("LegacyDataCompatibility", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Simulate legacy data (JSON without marker byte)
legacyData := []byte(`{"legacy":"data"}`)
var result interface{}
err := cache.deserialize(legacyData, &result)
require.NoError(t, err)
// Should successfully unmarshal as JSON
resultMap, ok := result.(map[string]interface{})
require.True(t, ok, "Should unmarshal legacy JSON data")
assert.Equal(t, "data", resultMap["legacy"])
})
t.Run("EmptyDataHandling", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
var result interface{}
err := cache.deserialize([]byte{}, &result)
assert.Error(t, err, "Should error on empty data")
assert.Contains(t, err.Error(), "empty data")
})
t.Run("OverflowProtection_LargeBytes", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a byte slice that exceeds maxCacheEntrySize (64 MiB)
oversizedBytes := make([]byte, 65*1024*1024) // 65 MiB
// Attempt to serialize - should fail with overflow error
_, err := cache.serialize(oversizedBytes)
require.Error(t, err, "Should error on oversized byte slice")
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
})
t.Run("OverflowProtection_ExactMaxSize", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a byte slice exactly at maxCacheEntrySize
// This should fail because adding marker byte would overflow
exactMaxBytes := make([]byte, 64*1024*1024) // Exactly 64 MiB
_, err := cache.serialize(exactMaxBytes)
require.Error(t, err, "Should error when adding marker would overflow")
assert.Contains(t, err.Error(), "would overflow when adding marker byte")
})
t.Run("OverflowProtection_SafeSize", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a byte slice well within limits
safeBytes := make([]byte, 1024*1024) // 1 MiB - safe size
serialized, err := cache.serialize(safeBytes)
require.NoError(t, err, "Should succeed with safe size")
assert.NotNil(t, serialized)
assert.Equal(t, len(safeBytes)+1, len(serialized), "Should add marker byte")
})
t.Run("OverflowProtection_JSONData", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a very large map that will exceed limits when JSON-encoded
largeMap := make(map[string]string)
// Each entry is roughly 50 bytes, so we need ~1.3M entries to exceed 64 MiB
for i := 0; i < 1400000; i++ {
key := fmt.Sprintf("key_%d", i)
largeMap[key] = "value_with_some_content_to_make_it_larger"
}
_, err := cache.serialize(largeMap)
require.Error(t, err, "Should error when JSON serialization exceeds size limit")
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
})
}
// TestUniversalCache_RedisIntegration_Issue116 tests the complete fix for issue #116
// with actual Redis backend to ensure metadata cache works correctly
func TestUniversalCache_RedisIntegration_Issue116(t *testing.T) {
t.Parallel()
// Start miniredis server
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
// Create Redis backend
redisConfig := backends.DefaultRedisConfig(mr.Addr())
redisConfig.RedisPrefix = "test:"
backend, err := backends.NewRedisBackend(redisConfig)
require.NoError(t, err)
defer backend.Close()
t.Run("MetadataCache_StoreAndRetrieve", func(t *testing.T) {
// Create cache with Redis backend
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
}, backend)
defer cache.Close()
// Simulate metadata_cache.Set behavior:
// 1. Marshal metadata to JSON
metadata := ProviderMetadata{
Issuer: "https://example.com",
JWKSURL: "https://example.com/jwks",
TokenURL: "https://example.com/token",
AuthURL: "https://example.com/authorize",
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
// 2. Store the JSON bytes
key := "v2:https://example.com"
err = cache.Set(key, jsonData, 1*time.Hour)
require.NoError(t, err)
// 3. Retrieve the data
retrieved, exists := cache.Get(key)
require.True(t, exists, "Data should exist in cache")
// 4. Should get back []byte (not a string or map)
retrievedBytes, ok := retrieved.([]byte)
require.True(t, ok, "Retrieved value should be []byte, got %T", retrieved)
// 5. Should be able to unmarshal as JSON
var retrievedMetadata ProviderMetadata
err = json.Unmarshal(retrievedBytes, &retrievedMetadata)
require.NoError(t, err, "Should be able to unmarshal retrieved bytes as JSON")
// 6. Verify data integrity
assert.Equal(t, metadata.Issuer, retrievedMetadata.Issuer)
assert.Equal(t, metadata.JWKSURL, retrievedMetadata.JWKSURL)
assert.Equal(t, metadata.TokenURL, retrievedMetadata.TokenURL)
})
t.Run("MetadataCache_NoBase64Encoding", func(t *testing.T) {
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
}, backend)
defer cache.Close()
// Store JSON bytes
jsonData := []byte(`{"issuer":"https://test.com"}`)
key := "v2:https://test.com"
err = cache.Set(key, jsonData, 1*time.Hour)
require.NoError(t, err)
// Retrieve
retrieved, exists := cache.Get(key)
require.True(t, exists)
retrievedBytes, ok := retrieved.([]byte)
require.True(t, ok)
// The retrieved data should NOT start with "eyJ" (Base64 encoding of "{")
// This was the bug in issue #116
assert.NotEqual(t, []byte("eyJ"), retrievedBytes[:3], "Data should not be Base64 encoded")
// Should be valid JSON
var checkJSON map[string]interface{}
err = json.Unmarshal(retrievedBytes, &checkJSON)
require.NoError(t, err, "Data should be valid JSON")
assert.Equal(t, "https://test.com", checkJSON["issuer"])
})
t.Run("TokenCache_MapValues", func(t *testing.T) {
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 100,
}, backend)
defer cache.Close()
// Store a map (like TokenCache does)
claims := map[string]interface{}{
"sub": "user123",
"exp": 1234567890.0,
"scope": "read write",
}
key := "token:abc123"
err = cache.Set(key, claims, 10*time.Minute)
require.NoError(t, err)
// Retrieve
retrieved, exists := cache.Get(key)
require.True(t, exists)
// Should get back a map
retrievedMap, ok := retrieved.(map[string]interface{})
require.True(t, ok, "Retrieved value should be map[string]interface{}")
assert.Equal(t, "user123", retrievedMap["sub"])
assert.Equal(t, 1234567890.0, retrievedMap["exp"])
})
t.Run("MixedTypes_SameCache", func(t *testing.T) {
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
}, backend)
defer cache.Close()
// Store different types
jsonBytes := []byte(`{"type":"json-bytes"}`)
err = cache.Set("key1", jsonBytes, 1*time.Hour)
require.NoError(t, err)
mapData := map[string]interface{}{"type": "map"}
err = cache.Set("key2", mapData, 1*time.Hour)
require.NoError(t, err)
stringData := "plain-string"
err = cache.Set("key3", stringData, 1*time.Hour)
require.NoError(t, err)
// Retrieve and verify each type
val1, exists := cache.Get("key1")
require.True(t, exists)
bytes1, ok := val1.([]byte)
require.True(t, ok)
assert.Equal(t, jsonBytes, bytes1)
val2, exists := cache.Get("key2")
require.True(t, exists)
map2, ok := val2.(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "map", map2["type"])
val3, exists := cache.Get("key3")
require.True(t, exists)
str3, ok := val3.(string)
require.True(t, ok)
assert.Equal(t, stringData, str3)
})
}
// TestUniversalCache_BackwardCompatibility tests that old cached data is handled gracefully
func TestUniversalCache_BackwardCompatibility(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisConfig := backends.DefaultRedisConfig(mr.Addr())
backend, err := backends.NewRedisBackend(redisConfig)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("LegacyJSONData", func(t *testing.T) {
// Manually insert legacy data (plain JSON without marker)
legacyKey := "general:legacy-key"
legacyData := []byte(`{"old":"format"}`)
err = backend.Set(ctx, legacyKey, legacyData, 1*time.Hour)
require.NoError(t, err)
// Try to retrieve via UniversalCache
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
}, backend)
defer cache.Close()
retrieved, exists := cache.Get("legacy-key")
require.True(t, exists, "Should retrieve legacy data")
// Should deserialize as JSON map
retrievedMap, ok := retrieved.(map[string]interface{})
require.True(t, ok, "Should unmarshal legacy JSON")
assert.Equal(t, "format", retrievedMap["old"])
})
t.Run("LegacyCorruptData", func(t *testing.T) {
// Insert corrupt/invalid data
corruptKey := "general:corrupt-key"
corruptData := []byte("not json and no marker")
err = backend.Set(ctx, corruptKey, corruptData, 1*time.Hour)
require.NoError(t, err)
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
}, backend)
defer cache.Close()
retrieved, exists := cache.Get("corrupt-key")
require.True(t, exists)
// Should return as raw bytes (fallback)
retrievedBytes, ok := retrieved.([]byte)
require.True(t, ok, "Should return corrupt data as raw bytes")
assert.Equal(t, corruptData, retrievedBytes)
})
}
// TestMetadataCache_Issue116_Regression is the main regression test for issue #116
// This specifically tests the scenario described in the GitHub issue
func TestMetadataCache_Issue116_Regression(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
// Create Redis backend
redisConfig := backends.DefaultRedisConfig(mr.Addr())
redisConfig.RedisPrefix = "traefik:"
backend, err := backends.NewRedisBackend(redisConfig)
require.NoError(t, err)
defer backend.Close()
// Create a simple logger
logger := GetSingletonNoOpLogger()
// Create metadata cache instance
metadataCache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
Logger: logger,
SkipAutoCleanup: true,
}, backend)
defer metadataCache.Close()
// Use the actual MetadataCache wrapper
wg := &sync.WaitGroup{}
mc := &MetadataCache{
cache: metadataCache,
logger: logger,
wg: wg,
}
// Test: Store and retrieve metadata (the scenario from issue #116)
providerURL := "https://example.com"
metadata := &ProviderMetadata{
Issuer: "https://example.com",
AuthURL: "https://example.com/authorize",
TokenURL: "https://example.com/token",
JWKSURL: "https://example.com/jwks",
RevokeURL: "https://example.com/revoke",
EndSessionURL: "https://example.com/logout",
RegistrationURL: "https://example.com/register",
ScopesSupported: []string{"openid", "profile", "email"},
}
// Store metadata
err = mc.Set(providerURL, metadata, 1*time.Hour)
require.NoError(t, err, "Should store metadata without error")
// Retrieve metadata
retrieved, exists := mc.Get(providerURL)
require.True(t, exists, "Should retrieve stored metadata")
require.NotNil(t, retrieved, "Retrieved metadata should not be nil")
// Verify no corruption - this was failing in issue #116 with "invalid character 'e'" error
assert.Equal(t, metadata.Issuer, retrieved.Issuer)
assert.Equal(t, metadata.AuthURL, retrieved.AuthURL)
assert.Equal(t, metadata.TokenURL, retrieved.TokenURL)
assert.Equal(t, metadata.JWKSURL, retrieved.JWKSURL)
// Verify the data is not Base64-encoded in Redis
// This checks the root cause mentioned in the issue
ctx := context.Background()
rawData, _, exists, err := backend.Get(ctx, "metadata:v2:"+providerURL)
require.NoError(t, err)
require.True(t, exists)
// Strip the marker byte
require.Greater(t, len(rawData), 1, "Data should have marker byte")
dataWithoutMarker := rawData[1:]
// Should not start with "eyJ" (Base64 encoding of "{")
if len(dataWithoutMarker) >= 3 {
assert.NotEqual(t, "eyJ", string(dataWithoutMarker[:3]), "Data should not be Base64-encoded")
}
// Should be valid JSON
var checkMetadata ProviderMetadata
err = json.Unmarshal(dataWithoutMarker, &checkMetadata)
require.NoError(t, err, "Stored data should be valid JSON, not Base64")
assert.Equal(t, metadata.Issuer, checkMetadata.Issuer)
}
+108 -51
View File
@@ -13,20 +13,23 @@ import (
// It runs a single consolidated cleanup goroutine for all caches, reducing
// goroutine count and CPU overhead compared to per-cache cleanup routines.
type UniversalCacheManager struct {
sharedBackend backends.CacheBackend
ctx context.Context
tokenTypeCache *UniversalCache
jwkCache *UniversalCache
sessionCache *UniversalCache
introspectionCache *UniversalCache
tokenCache *UniversalCache
metadataCache *UniversalCache
logger *Logger
blacklistCache *UniversalCache
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
cleanupStarted bool
sharedBackend backends.CacheBackend
ctx context.Context
tokenTypeCache *UniversalCache
jwkCache *UniversalCache
sessionCache *UniversalCache
introspectionCache *UniversalCache
tokenCache *UniversalCache
metadataCache *UniversalCache
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator)
logger *Logger
blacklistCache *UniversalCache
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
cleanupStarted bool
}
var (
@@ -169,6 +172,28 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
// Initialize session invalidation cache for backchannel/front-channel logout
// This cache stores invalidated session IDs and subjects to revoke sessions
manager.sessionInvalidationCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 5000, // Support many concurrent invalidations
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
// Refresh-result cache: short-lived store keyed by sha256(refreshToken).
// In Redis-backed mode this gives cross-replica dedup of refresh grants;
// in memory-only mode it's effectively redundant with RefreshCoordinator
// but safe and cheap to keep.
manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 5 * time.Second,
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
}
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
@@ -185,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
RedisPrefix: redisConfig.KeyPrefix,
PoolSize: redisConfig.PoolSize,
EnableMetrics: true,
EnableTLS: redisConfig.EnableTLS,
TLSSkipVerify: redisConfig.TLSSkipVerify,
}
// Use concrete type to avoid Yaegi reflection issues with interface assignment
@@ -349,6 +376,47 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
SkipAutoCleanup: true, // Managed cleanup
})
// DCR credentials cache - CRITICAL for distributed DCR across multiple nodes
// Uses Redis backend to share client credentials across all Traefik replicas
manager.dcrCredentialsCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100, // Few providers expected
DefaultTTL: 30 * 24 * time.Hour, // 30 days default (credentials are long-lived)
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
},
createBackend("dcr"),
)
// Session invalidation cache - CRITICAL for distributed backchannel/front-channel logout
// Uses Redis backend to share session invalidations across all Traefik replicas
manager.sessionInvalidationCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 5000, // Support many concurrent invalidations
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
},
createBackend("session_invalidation"),
)
// Refresh-result cache - shared via Redis so concurrent refreshes across
// Traefik replicas can dedup their grants. The 5s TTL is long enough for
// peers to observe a recent refresh and short enough that a stale entry
// can't be replayed against a now-rotated refresh token.
manager.refreshResultCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 5 * time.Second,
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
},
createBackend("refresh_result"),
)
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
}
@@ -396,6 +464,9 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
m.sessionCache,
m.introspectionCache,
m.tokenTypeCache,
m.dcrCredentialsCache,
m.sessionInvalidationCache,
m.refreshResultCache,
}
m.mu.RUnlock()
@@ -437,13 +508,6 @@ func (m *UniversalCacheManager) GetJWKCache() *UniversalCache {
return m.jwkCache
}
// GetSessionCache returns the session cache
func (m *UniversalCacheManager) GetSessionCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.sessionCache
}
// GetIntrospectionCache returns the token introspection cache
func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache {
m.mu.RLock()
@@ -458,6 +522,28 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
return m.tokenTypeCache
}
// GetSessionInvalidationCache returns the session invalidation cache for backchannel/front-channel logout
func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.sessionInvalidationCache
}
// GetRefreshResultCache returns the short-lived refresh-result cache used to
// coalesce refresh-token grants across Traefik replicas.
func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.refreshResultCache
}
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.dcrCredentialsCache
}
// Close shuts down all caches and the consolidated cleanup routine
func (m *UniversalCacheManager) Close() error {
// Stop the consolidated cleanup routine first
@@ -473,7 +559,7 @@ func (m *UniversalCacheManager) Close() error {
// Close all caches first (they won't close the shared backend)
for _, cache := range []*UniversalCache{
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache,
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache,
} {
if cache != nil {
_ = cache.Close() // Safe to ignore: best effort cache cleanup
@@ -494,35 +580,6 @@ func (m *UniversalCacheManager) Close() error {
return nil
}
// InitializeCacheManagerFromConfig initializes the cache manager with configuration
// This should be called early in the application startup with the loaded configuration
func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
logger := NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides only when
// the config field is not explicitly set through Traefik
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
// Log cache backend selection
if config.Redis != nil && config.Redis.Enabled {
logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
config.Redis.CacheMode, config.Redis.Address)
} else {
logger.Info("Initializing cache backend with memory-only mode")
}
return GetUniversalCacheManagerWithConfig(logger, config.Redis)
}
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
// This should only be called in test code to ensure proper cleanup between tests
func ResetUniversalCacheManagerForTesting() {
+5
View File
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
t.safeLogDebug("metadataRefreshStopChan closed")
}
if t.refreshCoordinator != nil {
t.refreshCoordinator.Shutdown()
t.safeLogDebug("refreshCoordinator shut down")
}
if t.goroutineWG != nil {
done := make(chan struct{})
go func() {

Some files were not shown because too many files have changed in this diff Show More