Compare commits

..

15 Commits

Author SHA1 Message Date
lukaszraczylo ccbb98b9dd fix-issue-122 (#128) 2026-03-04 00:23:30 +00:00
Serhii Vasyliev 1362cc0dac Improve debug logging around callback URL matching (#126)
* Add debug logging around callback URL matching in ServeHTTP

The callback URL comparison at the core of OIDC flow had zero logging,
making it extremely difficult to diagnose redirect loop issues caused
by misconfigured callbackURL (e.g., full URL vs path-only).

Every other path comparison in ServeHTTP already logs debug info
(logout, backchannel, frontchannel, excluded URLs), but the callback
URL check was completely silent.

Added debug logs that show:
- The values being compared (request path vs configured callback)
- Whether the match succeeded or failed
- Configured redirURLPath during initialization

This would have immediately revealed the root cause of issue #1
where callbackURL was set as a full URL but compared against
req.URL.Path which only contains the path component.

Closes #3

* improve-callback-url-logging: Add init-time logging for callbackURL config
2026-02-23 10:36:37 +00:00
Yuval Bar-On 249dcad1b3 fix: prevent deadlock in SessionData.Clear method (#114)
Move mutex unlock before calling Save() to prevent potential deadlock
when Save() method needs to acquire the same mutex.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-16 15:02:33 +00:00
lukaszraczylo de4b4d7258 fix(cache): remove sync.Pool for Yaegi compatibility (#121)
- [x] Remove sync.Pool implementation that causes reflection panics
- [x] Replace pool-based NewRESPWriter with direct instantiation
- [x] Replace pool-based NewRESPReader with direct instantiation
- [x] Convert Release() methods to no-ops for API compatibility
- [x] Add documentation explaining sync.Pool removal for Yaegi
- [x] Remove "sync" import

Resolves #120
2026-01-19 17:52:31 +00:00
lukaszraczylo 9d52f1b018 feat(core): refactor linters config and improve code quality (#119)
- [x] Reorganize golangci-lint configuration with documented disable reasons
- [x] Simplify errcheck and revive linter rules with targeted exclusions
- [x] Pre-compile regex patterns in input_validation.go for performance
- [x] Fix type assertions in memory_shard.go and resp.go with safety checks
- [x] Replace string comparison with EqualFold for case-insensitive matching
- [x] Fix loop variable captures in jwk.go and logout.go
- [x] Change high goroutine log level from Info to Debug in autocleanup.go
- [x] Replace deprecated "cancelled" spelling with "canceled" throughout
- [x] Add nolint annotations for intentional unused parameters
- [x] Improve comment formatting for deprecated functions
- [x] Fix comment spelling: "marshalling" → "marshaling"
- [x] Refactor provider warnings formatting in internal/providers/warnings.go
- [x] Simplify metrics summary building in internal/recovery/metrics.go
- [x] Pre-allocate slice in error_recovery.go GetDegradedServices
- [x] Refactor context cancellation checks in redis.go
2026-01-15 10:40:49 +00:00
lukaszraczylo 57724918fe fix 116 (#118)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases

* docs: add security fix documentation for integer overflow protection

* test: fix goroutine tests to use mock OIDC servers

The TestContextAwareGoroutineManagement tests were making real HTTP
calls to hardcoded URLs like https://example.com, causing failures
in CI when those requests timeout or return HTTP errors.

Changes:
- Added createMockOIDCServer() helper function using httptest
- Updated GoroutineCleanupOnContextCancel to use mock server
- Updated NoGoroutineLeakOnMultipleInstances to use 3 mock servers
- Updated SingletonTasksAcrossInstances to use mock servers array

This prevents network calls and makes tests more reliable and faster.

Fixes test failures in GitHub Actions CI.
2026-01-08 22:50:46 +00:00
lukaszraczylo 775de2ada1 Fix cache serialisation (#117)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases
2026-01-08 22:06:19 +00:00
lukaszraczylo 7816e05c98 fix issue with logout url (#112)
* fix(logout): handle logout requests before OIDC initialization

- [x] Add debug logging to logout handler entry point
- [x] Move logout path check before OIDC initialization to enable logout when provider unavailable
- [x] Move excluded URL and SSE checks before initialization wait
- [x] Add debug logging for initialization wait to diagnose hanging requests
- [x] Add test for logout functionality without OIDC provider availability

* feat(logout): implement OIDC backchannel and front-channel logout

- [x] Add logout token validation and backchannel logout handler
- [x] Add front-channel logout handler with iframe support
- [x] Implement session invalidation cache for distributed deployments
- [x] Add comprehensive logout token claim verification (issuer, audience, events, iat, sid/sub)
- [x] Integrate session invalidation checks into authorization flow
- [x] Add configuration options for enabling backchannel/front-channel logout
- [x] Add extensive test coverage for logout flows and edge cases
- [x] Update documentation with logout configuration examples
- [x] Add middleware routing for logout endpoints
- [x] Extend cache manager with session invalidation cache support

Resolves #110

* fixup! feat(logout): implement OIDC backchannel and front-channel logout

* fixup! Merge branch 'main' into fix-issue-with-logout-url
2026-01-04 01:59:50 +00:00
Dominik Chilla 8bf7998150 Fix for Hashicorp Vault - accept opaque access tokens with dot-characters (#113) 2026-01-02 16:42:22 +00:00
muffn_ 22c4323fcb fix: set X-Forwarded-User header for SSE requests from existing session (#111)
Co-authored-by: muffin <MonsterMuffin@users.noreply.github.com>
2026-01-02 02:50:11 +00:00
lukaszraczylo 06b219d1f8 feat(dcr): Add Redis storage support for multi-replica deployments (#109)
- [x] Add file and Redis storage backends for DCR credentials
- [x] Implement storage abstraction with FileStore and RedisStore
- [x] Add factory function for automatic backend selection (auto/file/redis)
- [x] Integrate DCR credentials cache into UniversalCacheManager
- [x] Add comprehensive tests for storage backends and factory
- [x] Update configuration schema with storage backend options
- [x] Update documentation with multi-replica deployment guidance
- [x] Add Redis key prefix configuration for credential isolation
2025-12-31 12:52:39 +00:00
lukaszraczylo 413e4a1b7d LRU + cache conflicts prevention. (#104)
* LRU + cache conflicts prevention.

* Bugfix universalCache flooding ( issue #105 )

  1. Traefik cancels the context for old plugin instances
  2. Each plugin's Close() method is called
  3. The CacheInterfaceWrapper.Close() was calling cache.Close() on the shared singleton caches
  4. Each Close() triggered Clear() which logged "Cleared all items" at INFO level
2025-12-24 18:54:39 +00:00
lukaszraczylo 69e0d98c67 fixup! Add signing of the plugin on release. 2025-12-24 12:33:33 +00:00
lukaszraczylo 6d893df12b Add signing of the plugin on release. 2025-12-15 00:38:35 +00:00
lukaszraczylo 6efb78b7a8 Smarter approach to the cookies (#103)
* Smarter approach to the cookies

  - Single maxCookieSize = 1400 constant with clear documentation
  - Combined cookie storage for ~40-45% size reduction
  - Backward compatible migration from legacy cookies

* Tuneup the code.
2025-12-12 18:35:06 +00:00
132 changed files with 9887 additions and 2163 deletions
+2
View File
@@ -11,7 +11,9 @@ on:
workflow_dispatch:
permissions:
id-token: write
contents: write
packages: write
jobs:
release:
+1
View File
@@ -1,3 +1,4 @@
docker/
.claude/*.out
*.test
.leann/
+49 -32
View File
@@ -14,21 +14,22 @@ linters:
- gosec
- misspell
- noctx
- nolintlint
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- whitespace
disable:
- exhaustive
- funlen
- gocognit
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
- goprintffuncname # Disabled: naming convention is project-specific
- lll
- mnd
- testpackage
- whitespace # Disabled: style preference about newlines
- wsl
settings:
dupl:
@@ -47,29 +48,13 @@ linters:
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
ignore-tests: true
gocritic:
# Using default enabled checks in v2
enabled-checks:
- appendCombine
- boolExprSimplify
- builtinShadow
- commentedOutCode
- emptyFallthrough
- equalFold
- hexLiteral
- indexAlloc
- initClause
- methodExprCall
- nestingReduce
- rangeExprCopy
- rangeValCopy
- stringXbytes
- typeAssertChain
- typeUnparen
- unlabelStmt
- yodaStyleExpr
# Disable style-only checks that add noise
disabled-checks:
- ifElseChain # Style preference, switch not always clearer
- elseif # Style preference
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
@@ -106,23 +91,23 @@ linters:
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
- name: if-return
# - name: exported # Disabled: too noisy, not all exported functions need comments
# - name: if-return # Disabled: style preference
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
# - name: var-declaration # Disabled: explicit zero values can be clearer
# - name: package-comments # Disabled: handled by other tools
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
# - name: indent-error-flow # Disabled: style preference
- name: errorf
- name: empty-block
# - name: empty-block # Disabled: sometimes empty blocks are intentional
- name: superfluous-else
- name: unused-parameter
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
- name: unreachable-code
- name: redefines-builtin-id
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
unparam:
check-exported: false
staticcheck:
@@ -132,8 +117,15 @@ linters:
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1011 # Omit type from declaration - style preference
- -QF1012 # Use fmt.Fprintf - style preference
- -SA9003 # Empty branch - sometimes intentional for future work
- -ST1000 # Package comment format - not required for all packages
- -ST1003 # Package name format - allowed for test packages
- -ST1016 # Receiver name consistency - legacy code
- -ST1020 # Comment format for methods - style preference
- -ST1021 # Comment format for types - style preference
- -ST1023 # Omit type from declaration - style preference
exclusions:
generated: lax
rules:
@@ -144,18 +136,43 @@ linters:
- goconst
- gocyclo
- gosec
- govet
- ineffassign
- noctx
- prealloc
- unparam
- revive
- gocritic
path: _test\.go
- linters:
- dupl
- gocyclo
- govet
- noctx
- prealloc
- unparam
- revive
- gocritic
path: test.*\.go
- linters:
- gocritic
- unused
- errcheck
- revive
path: mocks.*\.go
- linters:
- errcheck
- revive
- gocritic
- govet
- unparam
path: internal/testutil/
- linters:
- govet
- unparam
- noctx
- prealloc
path: integration/
- linters:
- gosec
text: 'G404:'
+11
View File
@@ -47,3 +47,14 @@ release:
name_template: "v{{ .Version }}"
draft: false
prerelease: auto
signs:
- cmd: cosign
signature: "${artifact}.sigstore.json"
args:
- sign-blob
- "--bundle=${signature}"
- "${artifact}"
- "--yes"
artifacts: checksum
output: true
+334
View File
@@ -123,6 +123,7 @@ testData:
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
stripAuthCookies: false # Strip OIDC session cookies before forwarding to backend (default: false)
# Security Headers Configuration (enabled by default with 'default' profile)
securityHeaders:
@@ -1021,6 +1022,108 @@ configuration:
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
required: false
stripAuthCookies:
type: boolean
description: |
Strip OIDC session cookies from the request before forwarding to backend services.
When enabled, the middleware removes all cookies with the OIDC prefix (default: _oidc_raczylo_)
from the Cookie header before the request is forwarded to the backend. The cookies remain
in the browser and are still sent to Traefik for session management — they are only removed
from the Traefik-to-backend hop.
This prevents "431 Request Header Fields Too Large" errors caused by large OIDC session
cookies (which can reach ~28KB with token chunking) being forwarded to backend services
that have limited header buffer sizes.
Non-OIDC cookies (application sessions, preferences, etc.) are always passed through
untouched.
Use this option when:
- Backend services return "431 Request Header Fields Too Large" errors
- OIDC session cookies are large due to token chunking
- Backend services don't need OIDC session cookies
- You want to reduce Cookie header overhead on backend requests
Can be combined with minimalHeaders for maximum header size reduction.
Default: false (all cookies forwarded for backward compatibility)
See: https://github.com/lukaszraczylo/traefikoidc/issues/122
required: false
enableBackchannelLogout:
type: boolean
description: |
Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST).
When enabled, the middleware accepts logout tokens at the configured backchannelLogoutURL.
The IdP sends a signed JWT (logout_token) to notify the application that a user's session
should be terminated.
This implements the OIDC Back-Channel Logout 1.0 specification.
See: https://openid.net/specs/openid-connect-backchannel-1_0.html
Requirements:
- backchannelLogoutURL must be configured
- The IdP must be configured to send logout tokens to your backchannel URL
- Logout tokens are validated using the IdP's JWKS
Default: false
required: false
backchannelLogoutURL:
type: string
description: |
Path for receiving backchannel logout tokens from the IdP.
This endpoint receives POST requests with a logout_token JWT in the request body.
The token is validated against the IdP's JWKS and contains the session ID (sid)
and/or subject (sub) to invalidate.
Example: /backchannel-logout
The full URL to configure in your IdP would be:
https://your-app.example.com/backchannel-logout
Note: This path should be unique and not conflict with your application routes.
required: false
enableFrontchannelLogout:
type: boolean
description: |
Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe).
When enabled, the middleware accepts logout requests at the configured frontchannelLogoutURL.
The IdP embeds an iframe pointing to this URL when the user logs out, allowing the
application to clear the user's session.
This implements the OIDC Front-Channel Logout 1.0 specification.
See: https://openid.net/specs/openid-connect-frontchannel-1_0.html
Requirements:
- frontchannelLogoutURL must be configured
- The IdP must be configured with your front-channel logout URL
- Your CSP headers must allow being embedded in an iframe from the IdP
Default: false
required: false
frontchannelLogoutURL:
type: string
description: |
Path for receiving front-channel logout requests from the IdP.
This endpoint receives GET requests with optional sid (session ID) and iss (issuer)
query parameters. When called, it invalidates the user's session.
Example: /frontchannel-logout
The full URL to configure in your IdP would be:
https://your-app.example.com/frontchannel-logout
Note: This path should be unique and not conflict with your application routes.
required: false
headers:
type: array
description: |
@@ -1630,3 +1733,234 @@ configuration:
Default: 30 seconds
required: false
dynamicClientRegistration:
type: object
description: |
Configuration for OIDC Dynamic Client Registration (RFC 7591/7592).
Dynamic Client Registration allows the middleware to automatically register
itself as an OAuth 2.0 client with the OIDC provider, eliminating the need
to manually create and manage client credentials.
This is particularly useful for:
- Automated deployments where manual client creation is impractical
- Multi-tenant scenarios requiring per-deployment client isolation
- Development and testing environments
- Kubernetes environments with multiple replicas
For multi-replica deployments (Kubernetes), enable Redis storage to share
credentials across all instances and prevent registration race conditions.
Example configuration:
```yaml
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: "redis" # Use Redis for distributed storage
clientMetadata:
redirect_uris:
- https://app.example.com/oauth2/callback
client_name: "My Application"
application_type: "web"
```
required: false
properties:
enabled:
type: boolean
description: |
Enable dynamic client registration with the OIDC provider.
When enabled and clientID is not set, the middleware will automatically
register itself with the provider using the configuration in clientMetadata.
Default: false
required: false
persistCredentials:
type: boolean
description: |
Enable persistence of client credentials after registration.
When enabled, credentials are saved to the configured storage backend
and reloaded on restart to avoid re-registration.
Default: false
required: false
storageBackend:
type: string
description: |
Storage backend for persisting DCR credentials.
Options:
- "file": Store credentials in a local file (default for backward compatibility)
- "redis": Store credentials in Redis (recommended for multi-replica deployments)
- "auto": Use Redis if available, fall back to file storage
For Kubernetes deployments with multiple replicas, use "redis" to ensure
all instances share the same client credentials and prevent registration
race conditions where each replica registers its own client.
Default: "auto"
required: false
enum:
- file
- redis
- auto
credentialsFile:
type: string
description: |
Path to store client credentials when using file-based storage.
The file will be created with restrictive permissions (0600).
Default: "/tmp/oidc-client-credentials.json"
required: false
redisKeyPrefix:
type: string
description: |
Prefix for Redis keys when using Redis storage.
Useful for isolating credentials between different applications
or environments sharing the same Redis instance.
Default: "dcr:creds:"
required: false
registrationEndpoint:
type: string
description: |
Override the registration endpoint URL.
If not specified, the endpoint will be discovered from provider metadata.
Some providers may not advertise their registration endpoint in metadata,
in which case you need to specify it explicitly.
Example: "https://auth.example.com/oauth/register"
required: false
initialAccessToken:
type: string
description: |
Initial Access Token for protected registration endpoints.
Some providers require an access token to authorize client registration.
If your provider requires authentication for registration, obtain an
initial access token from the provider and configure it here.
For Kubernetes, you can use secret references:
urn:k8s:secret:namespace:secret-name:key
required: false
clientMetadata:
type: object
description: |
Client metadata to include in the registration request (RFC 7591).
This defines the properties of the OAuth 2.0 client to be registered.
required: false
properties:
redirect_uris:
type: array
description: |
Array of redirect URIs for the client. Required for registration.
These must match the callback URLs that will be used in authentication flows.
Example: ["https://app.example.com/oauth2/callback"]
required: true
items:
type: string
client_name:
type: string
description: |
Human-readable name of the client.
This is typically displayed in consent screens.
Example: "My Application"
required: false
application_type:
type: string
description: |
Type of application. Affects security defaults.
Options:
- "web": Server-side web application (default)
- "native": Native/mobile application
Default: "web"
required: false
grant_types:
type: array
description: |
OAuth 2.0 grant types the client will use.
Default: ["authorization_code", "refresh_token"]
required: false
items:
type: string
response_types:
type: array
description: |
OAuth 2.0 response types the client will use.
Default: ["code"]
required: false
items:
type: string
token_endpoint_auth_method:
type: string
description: |
Authentication method for the token endpoint.
Options:
- "client_secret_basic": HTTP Basic authentication (default)
- "client_secret_post": Client credentials in POST body
- "none": Public client (no authentication)
Default: "client_secret_basic"
required: false
scope:
type: string
description: |
Space-separated list of scopes the client is authorized to request.
Example: "openid profile email"
required: false
contacts:
type: array
description: |
Array of contact email addresses for the client administrator.
Example: ["admin@example.com"]
required: false
items:
type: string
logo_uri:
type: string
description: |
URL to the client's logo image for consent screens.
required: false
client_uri:
type: string
description: |
URL to the client's home page.
required: false
policy_uri:
type: string
description: |
URL to the client's privacy policy.
required: false
tos_uri:
type: string
description: |
URL to the client's terms of service.
required: false
+86 -1
View File
@@ -8,7 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration, with Redis storage support for multi-replica deployments
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
- **Domain restrictions**: Limit access to specific email domains or individual users
@@ -82,6 +82,19 @@ experimental:
2. Configure the middleware in your dynamic configuration (see examples below).
### Verifying Release Signatures
All release checksums are signed with [cosign](https://github.com/sigstore/cosign) using keyless signing. To verify:
```bash
# Download the checksum file and its sigstore bundle from the release
cosign verify-blob \
--certificate-identity-regexp "https://github.com/lukaszraczylo/traefikoidc/.*" \
--certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
--bundle "traefikoidc_v<version>_checksums.txt.sigstore.json" \
traefikoidc_v<version>_checksums.txt
```
### Local Development with Docker Compose
For local development or testing, you can use the provided Docker Compose setup:
@@ -141,6 +154,11 @@ The middleware supports the following configuration options:
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
| `stripAuthCookies` | Strip OIDC session cookies before forwarding to backend services | `false` | `true` |
| `enableBackchannelLogout` | Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST) | `false` | `true` |
| `backchannelLogoutURL` | The path for receiving backchannel logout tokens from the IdP | none | `/backchannel-logout` |
| `enableFrontchannelLogout` | Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe) | `false` | `true` |
| `frontchannelLogoutURL` | The path for receiving front-channel logout requests from the IdP | none | `/frontchannel-logout` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
@@ -1135,6 +1153,50 @@ spec:
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
### With IdP-Initiated Logout (Backchannel & Front-Channel)
This plugin supports [OIDC Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) and [OIDC Front-Channel Logout](https://openid.net/specs/openid-connect-frontchannel-1_0.html) for IdP-initiated single logout.
**Backchannel Logout** (recommended): The IdP sends a server-to-server POST request with a signed `logout_token` JWT when a user logs out.
**Front-Channel Logout**: The IdP loads an iframe with the logout URL to invalidate the session in the browser.
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-idp-logout
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout # RP-initiated logout
# Backchannel Logout (server-to-server)
enableBackchannelLogout: true
backchannelLogoutURL: /backchannel-logout
# Front-Channel Logout (iframe-based)
enableFrontchannelLogout: true
frontchannelLogoutURL: /frontchannel-logout
# For multi-replica deployments, use Redis to share session invalidations
redis:
enabled: true
address: redis:6379
```
> **Note**: For multi-replica deployments, you **must** enable Redis to share session invalidation state across all instances. Otherwise, a logout on one instance won't invalidate sessions on other instances.
**IdP Configuration**: Configure your IdP to send logout requests to:
- **Backchannel**: `https://your-app.example.com/backchannel-logout` (POST with `logout_token`)
- **Front-Channel**: `https://your-app.example.com/frontchannel-logout?sid=SESSION_ID&iss=ISSUER` (GET in iframe)
### With Templated Headers
```yaml
@@ -1709,6 +1771,29 @@ This is particularly useful when:
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
#### Strip Auth Cookies Mode
If your backend services return **"431 Request Header Fields Too Large"** errors due to large OIDC session cookies (which can reach ~28KB with token chunking), you can strip them before forwarding:
```yaml
http:
middlewares:
my-auth:
plugin:
traefikoidc:
stripAuthCookies: true
# ... other config
```
When `stripAuthCookies: true` is set:
- **Strips**: All OIDC session cookies (`_oidc_raczylo_*`) from the request before forwarding to the backend
- **Preserves**: All non-OIDC cookies (application sessions, preferences, etc.)
- **No browser impact**: Cookies remain in the browser and are still sent to Traefik for session management
This can be combined with `minimalHeaders: true` for maximum header size reduction.
See [GitHub Issue #122](https://github.com/lukaszraczylo/traefikoidc/issues/122) for details.
### Security Headers
The middleware also sets the following security headers:
+49
View File
@@ -0,0 +1,49 @@
# Security Fix: Integer Overflow Protection in Cache Serialization
## Summary
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
## Vulnerability
**Locations**: `universal_cache.go` lines 789 and 811
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
## Fix Applied
1. **Added size limit constant**:
```go
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
```
2. **Size validation before allocation**:
- Validates entry size doesn't exceed limit
- Validates adding marker byte won't overflow
- Returns descriptive error messages
3. **Comprehensive test coverage**:
- Oversized byte slices (>64 MiB)
- Exact max size edge case
- Safe sizes (normal operation)
- Large JSON data structures
## Verification
✅ All tests pass with race detection
✅ No security issues (golangci-lint, gosec)
✅ 76.3% test coverage maintained
## Impact
- No breaking changes
- Negligible performance overhead
- Prevents potential buffer overflows
- Predictable memory usage
---
**Date**: January 8, 2026
**Severity**: High → Resolved
+4 -4
View File
@@ -84,8 +84,8 @@ func TestAudienceValidation(t *testing.T) {
tests := []struct {
name string
audience string
expectError bool
errorContains string
expectError bool
}{
{
name: "valid custom audience URL",
@@ -163,8 +163,8 @@ func TestConfigAudienceValidation(t *testing.T) {
tests := []struct {
name string
audience string
wantErr bool
errContains string
wantErr bool
}{
{
name: "Empty audience is valid for backward compatibility",
@@ -732,11 +732,11 @@ func TestJWTAudienceVerification(t *testing.T) {
tokenCache := tc.addTokenCache(NewTokenCache())
tests := []struct {
tokenAudience interface{}
name string
configAudience string
tokenAudience interface{}
wantErr bool
errContains string
wantErr bool
skipReplayCheck bool
}{
{
+1 -1
View File
@@ -253,8 +253,8 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication_WithPKCE()
// TestIsAjaxRequest tests AJAX request detection
func (s *AuthFlowBehaviourSuite) TestIsAjaxRequest() {
testCases := []struct {
name string
headers map[string]string
name string
expectAjax bool
}{
{
+14 -14
View File
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
// It limits concurrent task execution and tracks failures to prevent system overload
type TaskCircuitBreaker struct {
state int32 // CircuitBreakerState
failureCount int32
lastFailureTime int64 // Unix timestamp
failureThreshold int32
timeout time.Duration
logger *Logger
// Concurrency limiting
concurrentTasks int32 // Current number of running tasks
maxConcurrent int32 // Maximum concurrent tasks allowed
activeTasks map[string]struct{} // Track active task names
tasksMu sync.RWMutex // Separate mutex for task tracking
activeTasks map[string]struct{}
lastFailureTime int64
timeout time.Duration
tasksMu sync.RWMutex
state int32
failureCount int32
failureThreshold int32
concurrentTasks int32
maxConcurrent int32
}
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
@@ -380,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
type TaskRegistry struct {
tasks map[string]*BackgroundTask
mu sync.RWMutex
cb *TaskCircuitBreaker
logger *Logger
mu sync.RWMutex
}
// GlobalTaskRegistry is the singleton instance for managing all background tasks
@@ -600,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
return globalTaskMemoryMonitor
}
// NewTaskMemoryMonitor creates a new memory monitor for task registry
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
//
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
return GetGlobalTaskMemoryMonitor(logger)
}
@@ -713,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
// Check for goroutine leaks (arbitrary threshold)
if stats.Goroutines > 100 {
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
}
// Check for heap growth without corresponding GC activity
+6 -6
View File
@@ -330,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "ValidGoogleTokens",
@@ -476,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
setupSession func() *SessionData
name string
providerType string
setupSession func() *SessionData
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "AzureProvider",
@@ -660,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
ts.tOidc.refreshGracePeriod = 60 * time.Second
tests := []struct {
name string
setupSession func() *SessionData
name string
description string
expectedAuth bool
expectedRefresh bool
expectedExpired bool
description string
}{
{
name: "UnauthenticatedWithRefreshToken",
+7 -7
View File
@@ -97,15 +97,15 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
t.Run("String method returns pressure name", func(t *testing.T) {
pressures := []struct {
level MemoryPressureLevel
name string
level MemoryPressureLevel
}{
{MemoryPressureNone, "None"},
{MemoryPressureLow, "Low"},
{MemoryPressureModerate, "Moderate"},
{MemoryPressureHigh, "High"},
{MemoryPressureCritical, "Critical"},
{MemoryPressureLevel(999), "Unknown"},
{level: MemoryPressureNone, name: "None"},
{level: MemoryPressureLow, name: "Low"},
{level: MemoryPressureModerate, name: "Moderate"},
{level: MemoryPressureHigh, name: "High"},
{level: MemoryPressureCritical, name: "Critical"},
{level: MemoryPressureLevel(999), name: "Unknown"},
}
for _, p := range pressures {
+3 -3
View File
@@ -155,9 +155,9 @@ type CacheStrategy interface {
// CacheEntry for backward compatibility
type CacheEntry struct {
Key string
Value interface{}
ExpiresAt time.Time
Value interface{}
Key string
}
// Cache is an alias for backward compatibility
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
// ListNode for backward compatibility
type ListNode struct {
Key string
Value interface{}
Next *ListNode
Prev *ListNode
Key string
}
// NewFixedMetadataCache creates a metadata cache with fixed configuration
+24 -8
View File
@@ -20,8 +20,9 @@ var (
cacheManagerInitOnce sync.Once
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
// GetGlobalCacheManager returns a singleton CacheManager instance.
//
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
@@ -61,7 +62,7 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
}
// GetSharedTokenCache returns the shared token cache
@@ -93,7 +94,7 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
}
// GetSharedTokenTypeCache returns the shared token type cache
@@ -101,7 +102,15 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
// for backchannel and front-channel logout (IdP-initiated logout)
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
}
// Close gracefully shuts down all cache components
@@ -121,7 +130,8 @@ func CleanupGlobalCacheManager() error {
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
type CacheInterfaceWrapper struct {
cache *UniversalCache
cache *UniversalCache
managed bool // If true, cache is managed globally and Close() is a no-op
}
// Set stores a value
@@ -149,9 +159,15 @@ func (c *CacheInterfaceWrapper) Cleanup() {
c.cache.Cleanup()
}
// Close shuts down the cache
// Close shuts down the cache if it's not managed globally.
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
// when multiple plugin instances are closed during Traefik configuration reloads.
func (c *CacheInterfaceWrapper) Close() {
// Close the underlying cache to stop goroutines
if c.managed {
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
return
}
// Standalone cache - close it properly to stop cleanup goroutines
if c.cache != nil {
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
}
+164 -11
View File
@@ -19,16 +19,16 @@ import (
// CacheTestCase represents a comprehensive test case for cache operations
type CacheTestCase struct {
setup func(*TestFramework)
execute func(*TestFramework) error
validate func(*testing.T, error, *TestFramework)
cleanup func(*TestFramework)
name string
cacheType string // "universal", "metadata", "bounded"
operation string // "get", "set", "evict", "cleanup"
setup func(*TestFramework) // Pre-test setup
execute func(*TestFramework) error // Test execution
validate func(*testing.T, error, *TestFramework) // Validation logic
cleanup func(*TestFramework) // Post-test cleanup
timeout time.Duration // Test timeout
parallel bool // Can run in parallel
skipReason string // Optional reason to skip
cacheType string
operation string
skipReason string
timeout time.Duration
parallel bool
}
// createTestCacheConfig creates a standard test configuration
@@ -219,6 +219,159 @@ func TestCacheInterfaceWrapper_Close(t *testing.T) {
nilWrapper.Close()
}
// TestCacheInterfaceWrapper_ManagedClose_Regression tests that managed cache wrappers
// don't close the underlying cache when Close() is called. This is a regression test
// for issue #105 where multiple plugin instances closing shared caches caused log flooding.
func TestCacheInterfaceWrapper_ManagedClose_Regression(t *testing.T) {
cm := getTestCacheManager(t)
// Get a managed cache wrapper
cache := cm.GetSharedTokenBlacklist()
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
// Verify it's marked as managed
if !wrapper.managed {
t.Error("Expected shared cache wrapper to be marked as managed")
}
// Set some data before Close
cache.Set("test-key", "test-value", time.Hour)
// Close the wrapper (should be a no-op for managed caches)
wrapper.Close()
// Verify the cache is still operational after Close
value, found := cache.Get("test-key")
if !found {
t.Error("Expected cache to still work after Close() on managed wrapper")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got %v", value)
}
// Can still set new values
cache.Set("new-key", "new-value", time.Hour)
newValue, found := cache.Get("new-key")
if !found || newValue != "new-value" {
t.Error("Expected to be able to set new values after Close() on managed wrapper")
}
}
// TestCacheInterfaceWrapper_StandaloneClose tests that standalone cache wrappers
// properly close the underlying cache when Close() is called.
func TestCacheInterfaceWrapper_StandaloneClose(t *testing.T) {
// Create a standalone cache (not from the global cache manager)
standaloneCache := NewCache()
wrapper, ok := standaloneCache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
// Verify it's NOT marked as managed
if wrapper.managed {
t.Error("Expected standalone cache wrapper to NOT be marked as managed")
}
// Set some data
standaloneCache.Set("test-key", "test-value", time.Hour)
// Get baseline goroutine count
baselineGoroutines := runtime.NumGoroutine()
// Close the wrapper (should actually close the underlying cache)
wrapper.Close()
// Give cleanup goroutine time to stop
time.Sleep(100 * time.Millisecond)
// Goroutine count should decrease (cleanup routine stopped)
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > baselineGoroutines {
// This is acceptable - other tests might have started goroutines
t.Logf("Goroutine count: baseline=%d, final=%d", baselineGoroutines, finalGoroutines)
}
}
// TestCacheInterfaceWrapper_MultipleInstancesClose_Regression tests that multiple
// plugin instances can close their cache wrappers without affecting shared caches.
// This is a regression test for issue #105.
func TestCacheInterfaceWrapper_MultipleInstancesClose_Regression(t *testing.T) {
cm := getTestCacheManager(t)
// Simulate multiple plugin instances getting cache references
instances := make([]*CacheInterfaceWrapper, 5)
for i := 0; i < 5; i++ {
cache := cm.GetSharedTokenBlacklist()
wrapper, ok := cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatal("Expected CacheInterfaceWrapper")
}
instances[i] = wrapper
// Each instance might set some data
cache.Set(fmt.Sprintf("instance-%d-key", i), fmt.Sprintf("value-%d", i), time.Hour)
}
// Close all instances (simulating plugin shutdown/reload)
for _, wrapper := range instances {
wrapper.Close()
}
// The shared cache should still work after all instances closed their wrappers
newCache := cm.GetSharedTokenBlacklist()
// Data set by earlier instances should still be accessible
for i := 0; i < 5; i++ {
key := fmt.Sprintf("instance-%d-key", i)
value, found := newCache.Get(key)
if !found {
t.Errorf("Expected data from instance %d to still be accessible", i)
}
expectedValue := fmt.Sprintf("value-%d", i)
if value != expectedValue {
t.Errorf("Expected '%s', got '%v'", expectedValue, value)
}
}
// Should be able to add new data
newCache.Set("after-close-key", "after-close-value", time.Hour)
value, found := newCache.Get("after-close-key")
if !found || value != "after-close-value" {
t.Error("Expected to be able to use cache after all wrapper Close() calls")
}
}
// TestAllSharedCachesMarkedAsManaged verifies all shared cache getters
// return managed wrappers to prevent the log flooding issue.
func TestAllSharedCachesMarkedAsManaged(t *testing.T) {
cm := getTestCacheManager(t)
tests := []struct {
name string
cache CacheInterface
}{
{"TokenBlacklist", cm.GetSharedTokenBlacklist()},
{"IntrospectionCache", cm.GetSharedIntrospectionCache()},
{"TokenTypeCache", cm.GetSharedTokenTypeCache()},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wrapper, ok := tt.cache.(*CacheInterfaceWrapper)
if !ok {
t.Fatalf("Expected CacheInterfaceWrapper for %s", tt.name)
}
if !wrapper.managed {
t.Errorf("%s cache wrapper should be marked as managed", tt.name)
}
})
}
}
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
cm := getTestCacheManager(t)
cache := cm.GetSharedTokenBlacklist()
@@ -698,10 +851,10 @@ func TestUnifiedCache_SetMaxSize(t *testing.T) {
func TestNewCacheAdapter(t *testing.T) {
tests := []struct {
name string
cache interface{}
expectNil bool
name string
description string
expectNil bool
}{
{
name: "UniversalCache",
+2 -2
View File
@@ -7,7 +7,7 @@ import (
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
@@ -47,7 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
+290
View File
@@ -0,0 +1,290 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
)
// DCRStorageBackend represents the type of storage backend for DCR credentials.
// Alias for internal package type for backward compatibility.
type DCRStorageBackend = dcrstorage.StorageBackend
const (
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
// DCRStorageBackendRedis uses Redis for distributed storage
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
)
// DCRCredentialsStore defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type DCRCredentialsStore interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
type loggerAdapter struct {
logger *Logger
}
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
type cacheAdapter struct {
cache *UniversalCache
}
func (c *cacheAdapter) Get(key string) (any, bool) {
return c.cache.Get(key)
}
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
return c.cache.Set(key, value, ttl)
}
func (c *cacheAdapter) Delete(key string) {
c.cache.Delete(key)
}
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
type fileStoreWrapper struct {
inner *dcrstorage.FileStore
}
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
func (w *fileStoreWrapper) basePath() string {
return w.inner.BasePath()
}
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
return w.inner.GetFilePath(providerURL)
}
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
type redisStoreWrapper struct {
inner *dcrstorage.RedisStore
}
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
type FileCredentialsStore = fileStoreWrapper
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
type RedisCredentialsStore = redisStoreWrapper
// NewFileCredentialsStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
return &fileStoreWrapper{inner: inner}
}
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
cacheAdapt := &cacheAdapter{cache: cache}
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
return &redisStoreWrapper{inner: inner}
}
// Helper functions to convert between main package and internal package types
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
if creds == nil {
return nil
}
return &dcrstorage.ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
if creds == nil {
return nil
}
return &ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
// This factory function handles backend selection logic:
// - "file": Use file-based storage (default for backward compatibility)
// - "redis": Use Redis exclusively (fails if Redis unavailable)
// - "auto": Use Redis if available, fallback to file
func NewDCRCredentialsStore(
config *DynamicClientRegistrationConfig,
cacheManager *CacheManager,
logger *Logger,
) (DCRCredentialsStore, error) {
if config == nil {
return nil, fmt.Errorf("DCR config is nil")
}
if logger == nil {
logger = GetSingletonNoOpLogger()
}
backend := config.StorageBackend
if backend == "" {
backend = string(DCRStorageBackendAuto) // Default to auto selection
}
switch DCRStorageBackend(backend) {
case DCRStorageBackendFile:
logger.Info("Using file-based storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
case DCRStorageBackendRedis:
cache := getDCRCache(cacheManager)
if cache == nil {
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
}
logger.Info("Using Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
case DCRStorageBackendAuto:
// Try Redis first, fallback to file
cache := getDCRCache(cacheManager)
if cache != nil && cache.backend != nil {
logger.Info("Auto-selected Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
}
logger.Info("Redis not available, using file storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
default:
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
}
}
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
if cacheManager == nil {
return nil
}
cacheManager.mu.RLock()
defer cacheManager.mu.RUnlock()
if cacheManager.manager == nil {
return nil
}
return cacheManager.manager.GetDCRCredentialsCache()
}
+663
View File
@@ -0,0 +1,663 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
t.Parallel()
// Create a temp directory for test files
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
// Save credentials
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
// Load credentials
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
// Verify fields
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
// Should not error
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
// TestFileCredentialsStore_MultiProvider tests multi-provider support
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
// Save credentials for both providers
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
// Load and verify each provider's credentials
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
// Delete one shouldn't affect the other
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
// Concurrent saves
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
// Concurrent loads
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
// Final verification
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
// TestFileCredentialsStore_InvalidInput tests error handling
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
// TestFileCredentialsStore_DefaultPath tests default path behavior
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore("", logger)
// Just verify we can create with empty path and it has a default
if store.basePath() == "" {
t.Error("Expected default base path")
}
}
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
t.Parallel()
// Create an in-memory cache for testing
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0, // No expiry
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
// TestRedisCredentialsStore_InvalidInput tests error handling
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
// TestDCRStorageFactory tests the factory function
func TestDCRStorageFactory(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
t.Run("nil config returns error", func(t *testing.T) {
_, err := NewDCRCredentialsStore(nil, nil, logger)
if err == nil {
t.Error("Expected error for nil config")
}
})
t.Run("file backend creates file store", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "file",
CredentialsFile: "/tmp/test-creds.json",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create file store: %v", err)
}
if store == nil {
t.Error("Expected store but got nil")
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore")
}
})
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "redis",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for redis backend without cache manager")
}
})
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "auto",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create auto store: %v", err)
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for auto without redis")
}
})
t.Run("unknown backend returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "unknown",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for unknown backend")
}
})
t.Run("empty backend defaults to auto", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create store with empty backend: %v", err)
}
// Should default to file (auto without redis)
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for empty backend")
}
})
}
// TestDynamicClientRegistrar_WithStore tests registrar with store
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil, // httpClient
logger,
config,
"https://auth.example.com",
store,
)
if registrar == nil {
t.Fatal("Expected registrar but got nil")
}
if registrar.store == nil {
t.Error("Expected store to be set")
}
// Test SetStore
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
registrar.SetStore(newStore)
if registrar.store != newStore {
t.Error("SetStore did not update the store")
}
}
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
providerURL := "https://auth.example.com"
ctx := context.Background()
// Pre-save credentials
testCreds := &ClientRegistrationResponse{
ClientID: "pre-saved-client",
ClientSecret: "pre-saved-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
}
if err := store.Save(ctx, providerURL, testCreds); err != nil {
t.Fatalf("Failed to pre-save credentials: %v", err)
}
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil,
logger,
config,
providerURL,
store,
)
// Test loading via the internal method
loaded, err := registrar.loadCredentialsFromStore(ctx)
if err != nil {
t.Fatalf("Failed to load from store: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != "pre-saved-client" {
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
}
}
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
// Write corrupted JSON
filePath := store.getFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
// Should return error for corrupted file
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(deepPath, logger)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+34 -1
View File
@@ -384,10 +384,14 @@ scopes:
### Dynamic Client Registration (RFC 7591)
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
**Basic Configuration (Single Instance):**
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional
initialAccessToken: "your-token" # Optional, if provider requires it
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
@@ -400,6 +404,35 @@ dynamicClientRegistration:
- "refresh_token"
```
**Multi-Replica Deployment (Kubernetes):**
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
```yaml
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: "redis" # Share credentials via Redis
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
**Storage Backend Options:**
| Backend | Description | Use Case |
|---------|-------------|----------|
| `file` | Store credentials in local file | Single instance deployments |
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
### Multi-Replica Deployment
Without Redis, disable replay detection:
+2
View File
@@ -353,6 +353,8 @@ allowPrivateIPAddresses: true # Required for private IPs
- Roles: User Client Role mapper with "Add to ID token" enabled
- Groups: Group Membership mapper with "Add to ID token" enabled
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
---
## AWS Cognito
+110 -1
View File
@@ -90,6 +90,7 @@
<a href="#configuration" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Configuration</a>
<a href="#deployment" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Deployment</a>
<a href="#security" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Security</a>
<a href="#logout" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Logout</a>
</div>
<div class="flex items-center space-x-4">
<button id="theme-toggle" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 p-2 min-w-[44px] min-h-[44px] flex items-center justify-center" aria-label="Toggle theme">
@@ -114,6 +115,7 @@
<a href="#configuration" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Configuration</a>
<a href="#deployment" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Deployment</a>
<a href="#security" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Security</a>
<a href="#logout" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Logout</a>
</div>
</div>
</nav>
@@ -193,7 +195,7 @@
</div>
<div>
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-1">Dynamic Registration</h3>
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration</p>
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration with Redis storage support for multi-replica deployments</p>
</div>
</div>
</div>
@@ -862,6 +864,48 @@ spec:
</table>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Dynamic Client Registration (RFC 7591)</h3>
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Automatically register your application with the OIDC provider. Supports Redis storage for multi-replica deployments:</p>
<div class="overflow-x-auto mb-4">
<table class="w-full text-sm">
<thead>
<tr class="border-b border-gray-200 dark:border-gray-700">
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Parameter</th>
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Default</th>
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Description</th>
</tr>
</thead>
<tbody class="text-gray-600 dark:text-gray-400">
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.enabled</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable dynamic client registration</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.persistCredentials</code></td>
<td class="py-2 px-3">true</td>
<td class="py-2 px-3">Persist registered credentials across restarts</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.storageBackend</code></td>
<td class="py-2 px-3">auto</td>
<td class="py-2 px-3">Storage backend: <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">file</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis</code>, or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">auto</code> (uses Redis if available)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.redisKeyPrefix</code></td>
<td class="py-2 px-3">dcr:creds:</td>
<td class="py-2 px-3">Redis key prefix for DCR credentials</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.clientMetadata.redirect_uris</code></td>
<td class="py-2 px-3">-</td>
<td class="py-2 px-3">Redirect URIs for the registered client (required)</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Security Headers with CORS</h3>
@@ -1177,6 +1221,71 @@ spec:
</div>
</section>
<!-- IdP-Initiated Logout Section -->
<section id="logout" class="py-12 sm:py-16 md:py-20 bg-white dark:bg-gray-900 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
<div class="text-center mb-8 sm:mb-12">
<h2 class="text-2xl sm:text-3xl md:text-4xl font-bold text-gray-900 dark:text-gray-100 mb-3 sm:mb-4">IdP-Initiated Logout</h2>
<p class="text-base sm:text-lg text-gray-600 dark:text-gray-300 px-4">Support for OIDC Back-Channel and Front-Channel Logout specifications</p>
</div>
<div class="max-w-4xl mx-auto">
<div class="grid md:grid-cols-2 gap-6 mb-8">
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-server mr-2 text-blue-500"></i>
Back-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Signed JWT logout tokens</li>
<li>&#8226; Session ID (sid) based invalidation</li>
<li>&#8226; Subject (sub) based invalidation</li>
<li>&#8226; Works behind firewalls</li>
</ul>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-browser mr-2 text-purple-500"></i>
Front-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Iframe-based session termination</li>
<li>&#8226; Immediate cookie invalidation</li>
<li>&#8226; Simple GET request handling</li>
<li>&#8226; Issuer validation</li>
</ul>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Configuration Example</h3>
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
# ... other OIDC configuration ...
# Back-Channel Logout (server-to-server)
enableBackchannelLogout: true
backchannelLogoutURL: "/backchannel-logout"
# Front-Channel Logout (browser-based)
enableFrontchannelLogout: true
frontchannelLogoutURL: "/frontchannel-logout"</code></pre>
<p class="text-gray-600 dark:text-gray-400 text-sm mt-4">
Configure your IdP with the full URLs (e.g., <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">https://your-app.example.com/backchannel-logout</code>).
When a user logs out from the IdP, all their sessions across your applications will be invalidated.
</p>
</div>
</div>
</div>
</section>
<!-- Why Choose Section -->
<section class="py-12 sm:py-16 md:py-20 bg-gray-50 dark:bg-gray-800 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
+103 -45
View File
@@ -16,35 +16,26 @@ import (
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
// Required fields
ClientID string `json:"client_id"`
// Conditional - only for confidential clients
ClientSecret string `json:"client_secret,omitempty"`
// Optional - for managing registration
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
// Expiration
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
// Echo back of registered metadata
RedirectURIs []string `json:"redirect_uris,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
Contacts []string `json:"contacts,omitempty"`
ClientName string `json:"client_name,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
SubjectType string `json:"subject_type,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// ClientRegistrationError represents an error response from client registration (RFC 7591)
@@ -55,14 +46,13 @@ type ClientRegistrationError struct {
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
type DynamicClientRegistrar struct {
httpClient *http.Client
logger *Logger
config *DynamicClientRegistrationConfig
providerURL string
// Cached registration response
mu sync.RWMutex
httpClient *http.Client
logger *Logger
config *DynamicClientRegistrationConfig
registrationResponse *ClientRegistrationResponse
store DCRCredentialsStore // Storage backend for credentials
providerURL string
mu sync.RWMutex
}
// NewDynamicClientRegistrar creates a new dynamic client registrar
@@ -84,8 +74,37 @@ func NewDynamicClientRegistrar(
}
}
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
func NewDynamicClientRegistrarWithStore(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
store DCRCredentialsStore,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
store: store,
}
}
// SetStore sets the credentials store for the registrar
// This allows setting the store after creation when the cache manager is available
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
r.mu.Lock()
defer r.mu.Unlock()
r.store = store
}
// RegisterClient performs dynamic client registration with the OIDC provider
// It first attempts to load existing credentials from a file if persistence is enabled,
// It first attempts to load existing credentials from storage if persistence is enabled,
// then registers a new client if no valid credentials exist.
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
if r.config == nil || !r.config.Enabled {
@@ -94,10 +113,13 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
// Try to load existing credentials if persistence is enabled
if r.config.PersistCredentials {
if resp, err := r.loadCredentials(); err == nil && resp != nil {
resp, err := r.loadCredentialsFromStore(ctx)
if err != nil {
r.logger.Debugf("Failed to load credentials from store: %v", err)
} else if resp != nil {
// Check if credentials are still valid (not expired)
if r.areCredentialsValid(resp) {
r.logger.Info("Loaded existing client credentials from file")
r.logger.Info("Loaded existing client credentials from storage")
r.mu.Lock()
r.registrationResponse = resp
r.mu.Unlock()
@@ -190,7 +212,7 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
// Persist credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist client credentials: %v", err)
// Don't fail registration if persistence fails
}
@@ -326,7 +348,44 @@ func (r *DynamicClientRegistrar) credentialsFilePath() string {
return "/tmp/oidc-client-credentials.json"
}
// saveCredentials persists client credentials to a file
// loadCredentialsFromStore loads client credentials from the configured storage backend
// Falls back to legacy file-based loading if no store is configured
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
// Use store if available
if r.store != nil {
return r.store.Load(ctx, r.providerURL)
}
// Fallback to legacy file-based loading
return r.loadCredentials()
}
// saveCredentialsToStore persists client credentials to the configured storage backend
// Falls back to legacy file-based saving if no store is configured
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
// Use store if available
if r.store != nil {
return r.store.Save(ctx, r.providerURL, resp)
}
// Fallback to legacy file-based saving
return r.saveCredentials(resp)
}
// deleteCredentialsFromStore removes credentials from the configured storage backend
// Falls back to legacy file-based deletion if no store is configured
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
// Use store if available
if r.store != nil {
return r.store.Delete(ctx, r.providerURL)
}
// Fallback to legacy file-based deletion
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// saveCredentials persists client credentials to a file (legacy method)
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
@@ -344,7 +403,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
return nil
}
// loadCredentials loads client credentials from a file
// loadCredentials loads client credentials from a file (legacy method)
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
@@ -431,7 +490,7 @@ func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
@@ -538,11 +597,10 @@ func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) e
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials file if persistence is enabled
// Remove credentials from storage if persistence is enabled
if r.config.PersistCredentials {
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
r.logger.Errorf("Failed to remove credentials file: %v", err)
if err := r.deleteCredentialsFromStore(ctx); err != nil {
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
}
}
+5 -5
View File
@@ -223,10 +223,10 @@ func TestRegisterClientWithInitialAccessToken(t *testing.T) {
// TestRegisterClientError tests error handling during registration
func TestRegisterClientError(t *testing.T) {
tests := []struct {
name string
serverResponse func(w http.ResponseWriter, r *http.Request)
expectError bool
name string
errorContains string
expectError bool
}{
{
name: "invalid_redirect_uri error",
@@ -321,8 +321,8 @@ func TestRegisterClientError(t *testing.T) {
// TestRegisterClientDisabled tests that registration fails when not enabled
func TestRegisterClientDisabled(t *testing.T) {
tests := []struct {
name string
dcrConfig *DynamicClientRegistrationConfig
name string
}{
{
name: "nil config",
@@ -521,8 +521,8 @@ func TestCredentialsValidation(t *testing.T) {
registrar := NewDynamicClientRegistrar(&http.Client{}, NewLogger("DEBUG"), dcrConfig, "https://example.com")
tests := []struct {
name string
response *ClientRegistrationResponse
name string
expected bool
}{
{
@@ -584,9 +584,9 @@ func TestCredentialsValidation(t *testing.T) {
// TestBuildRegistrationRequest tests the request body construction
func TestBuildRegistrationRequest(t *testing.T) {
tests := []struct {
name string
metadata *ClientRegistrationMetadata
expectedFields map[string]interface{}
name string
expectError bool
}{
{
+29 -47
View File
@@ -12,23 +12,19 @@ import (
// EnhancedMockJWKCache is an improved state-based mock with call tracking
type EnhancedMockJWKCache struct {
mu sync.RWMutex
// State (what to return)
JWKS *JWKSet
Err error
// Call tracking
Err error
JWKS *JWKSet
GetJWKSCalls []JWKSCall
mu sync.RWMutex
getJWKSCallsMu sync.Mutex
CleanupCalls int32
CloseCalls int32
getJWKSCallsMu sync.Mutex
}
// JWKSCall records parameters from a GetJWKS call
type JWKSCall struct {
URL string
Timestamp time.Time
URL string
}
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
@@ -108,22 +104,18 @@ func (m *EnhancedMockJWKCache) Reset() {
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
type EnhancedMockTokenVerifier struct {
mu sync.RWMutex
// State (what to return) - can be a fixed error or a function
Err error
VerifyFunc func(token string) error
// Call tracking
Err error
VerifyFunc func(token string) error
VerifyCalls []TokenVerifyCall
mu sync.RWMutex
verifyCallsMu sync.Mutex
}
// TokenVerifyCall records parameters from a VerifyToken call
type TokenVerifyCall struct {
Token string
Timestamp time.Time
Result error
Token string
}
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
@@ -207,49 +199,43 @@ func (m *EnhancedMockTokenVerifier) Reset() {
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
type EnhancedMockTokenExchanger struct {
mu sync.RWMutex
// State (what to return)
ExchangeResponse *TokenResponse
ExchangeErr error
RefreshResponse *TokenResponse
RefreshErr error
RevokeErr error
// Optional functions for dynamic behavior
ExchangeErr error
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshResponse *TokenResponse
ExchangeResponse *TokenResponse
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
// Call tracking
ExchangeCalls []ExchangeCall
RefreshCalls []RefreshCall
RevokeCalls []RevokeCall
exchangeCallsMu sync.Mutex
refreshCallsMu sync.Mutex
revokeCallsMu sync.Mutex
ExchangeCalls []ExchangeCall
RefreshCalls []RefreshCall
RevokeCalls []RevokeCall
mu sync.RWMutex
exchangeCallsMu sync.Mutex
refreshCallsMu sync.Mutex
revokeCallsMu sync.Mutex
}
// ExchangeCall records parameters from an ExchangeCodeForToken call
type ExchangeCall struct {
Timestamp time.Time
GrantType string
CodeOrToken string
RedirectURL string
CodeVerifier string
Timestamp time.Time
}
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
type RefreshCall struct {
RefreshToken string
Timestamp time.Time
RefreshToken string
}
// RevokeCall records parameters from a RevokeTokenWithProvider call
type RevokeCall struct {
Timestamp time.Time
Token string
TokenType string
Timestamp time.Time
}
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
@@ -401,16 +387,12 @@ func (m *EnhancedMockTokenExchanger) Reset() {
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
type EnhancedMockCacheInterface struct {
mu sync.RWMutex
// Internal storage
data map[string]cacheEntry
maxSize int
// Call tracking
data map[string]cacheEntry
GetCalls []CacheGetCall
SetCalls []CacheSetCall
DeleteCalls []string
maxSize int
mu sync.RWMutex
getCalls sync.Mutex
setCalls sync.Mutex
deleteCalls sync.Mutex
@@ -423,17 +405,17 @@ type cacheEntry struct {
// CacheGetCall records parameters from a Get call
type CacheGetCall struct {
Timestamp time.Time
Key string
Found bool
Timestamp time.Time
}
// CacheSetCall records parameters from a Set call
type CacheSetCall struct {
Key string
Value any
TTL time.Duration
Timestamp time.Time
Value any
Key string
TTL time.Duration
}
// NewEnhancedMockCache creates a new enhanced cache mock
+16 -37
View File
@@ -642,14 +642,10 @@ func (e *HTTPError) Error() string {
// OIDCError represents OIDC-specific errors with context information.
// It provides structured error reporting for authentication and authorization failures.
type OIDCError struct {
// Code identifies the specific error type
Code string
// Message provides a human-readable description
Message string
// Context contains additional error context (e.g., provider, session details)
Cause error
Context map[string]interface{}
// Cause is the underlying error that caused this error
Cause error
Code string
Message string
}
// Error returns the string representation of the OIDC error.
@@ -669,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
// SessionError represents session-related errors with context.
// Used for session management, validation, and storage errors.
type SessionError struct {
// Operation describes what session operation failed
Cause error
Operation string
// Message provides a human-readable description
Message string
// SessionID identifies the session (if available)
Message string
SessionID string
// Cause is the underlying error that caused this error
Cause error
}
// Error returns the string representation of the session error.
@@ -696,14 +688,10 @@ func (e *SessionError) Unwrap() error {
// TokenError represents token-related errors with validation context.
// Used for JWT validation, token refresh, and token format errors.
type TokenError struct {
// TokenType identifies the type of token (id_token, access_token, refresh_token)
Cause error
TokenType string
// Reason describes why the token is invalid
Reason string
// Message provides a human-readable description
Message string
// Cause is the underlying error that caused this error
Cause error
Reason string
Message string
}
// Error returns the string representation of the token error.
@@ -765,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
// It provides fallback mechanisms when primary services are unavailable and monitors
// service health to automatically recover when services become available again.
type GracefulDegradation struct {
// BaseRecoveryMechanism provides common functionality
*BaseRecoveryMechanism
// fallbacks stores service-specific fallback implementations
fallbacks map[string]func() (interface{}, error)
// healthChecks stores service health check functions
healthChecks map[string]func() bool
// degradedServices tracks which services are currently degraded
fallbacks map[string]func() (interface{}, error)
healthChecks map[string]func() bool
degradedServices map[string]time.Time
// config contains graceful degradation configuration
config GracefulDegradationConfig
// mutex protects shared state
mutex sync.RWMutex
// healthCheckTask manages background health checking
healthCheckTask *BackgroundTask
// stopChan signals shutdown
stopChan chan struct{}
// shutdownOnce ensures shutdown happens only once
shutdownOnce sync.Once
healthCheckTask *BackgroundTask
stopChan chan struct{}
config GracefulDegradationConfig
mutex sync.RWMutex
shutdownOnce sync.Once
}
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
@@ -975,7 +954,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
var degraded []string
degraded := make([]string, 0, len(gd.degradedServices))
for serviceName := range gd.degradedServices {
degraded = append(degraded, serviceName)
}
+9 -9
View File
@@ -20,10 +20,10 @@ import (
func TestCircuitBreakerStateTransitions(t *testing.T) {
tests := []struct {
name string
failures int
maxFailures int
expectedStateBefore string
expectedStateAfter string
failures int
maxFailures int
}{
{
name: "stays closed below threshold",
@@ -543,8 +543,8 @@ func TestRetryExecutorNetworkErrors(t *testing.T) {
}, nil)
tests := []struct {
name string
err error
name string
shouldRetry bool
}{
{
@@ -1647,8 +1647,8 @@ func TestGracefulDegradationFullScenario(t *testing.T) {
func TestIsTraefikDefaultCertError(t *testing.T) {
tests := []struct {
name string
err error
name string
expected bool
}{
{
@@ -1680,8 +1680,8 @@ func TestIsTraefikDefaultCertError(t *testing.T) {
func TestIsEOFError(t *testing.T) {
tests := []struct {
name string
err error
name string
expected bool
}{
{
@@ -1723,8 +1723,8 @@ func TestIsEOFError(t *testing.T) {
func TestIsCertificateError(t *testing.T) {
tests := []struct {
name string
err error
name string
expected bool
}{
{
@@ -1811,8 +1811,8 @@ func TestRetryExecutorStartupErrors(t *testing.T) {
_ = NewRetryExecutor(MetadataFetchRetryConfig(), nil)
tests := []struct {
name string
err error
name string
shouldRetry bool
}{
{
@@ -1890,8 +1890,8 @@ func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) {
re := NewRetryExecutor(DefaultRetryConfig(), nil)
tests := []struct {
name string
err error
name string
shouldRetry bool
}{
{
@@ -1977,9 +1977,9 @@ func circuitBreakerStateToString(state CircuitBreakerState) string {
}
type mockNetError struct {
msg string
timeout bool
temporary bool
msg string
}
func (e *mockNetError) Error() string { return e.msg }
+6 -6
View File
@@ -10,16 +10,16 @@ import (
type GoroutineManager struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
goroutines map[string]*managedGoroutine
logger *Logger
wg sync.WaitGroup
mu sync.RWMutex
}
type managedGoroutine struct {
name string
cancel context.CancelFunc
startTime time.Time
cancel context.CancelFunc
name string
running bool
}
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
// GoroutineStatus represents the status of a managed goroutine
type GoroutineStatus struct {
Name string
Running bool
StartTime time.Time
Name string
Runtime time.Duration
Running bool
}
// ErrShutdownTimeout is returned when shutdown times out
+1
View File
@@ -336,6 +336,7 @@ func createStringMap(keys []string) map[string]struct{} {
// and redirects to the provider's logout endpoint or configured post-logout URI.
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing logout request")
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
+12 -19
View File
@@ -12,30 +12,23 @@ import (
// HTTPClientConfig provides configuration for creating HTTP clients
type HTTPClientConfig struct {
// Timeout for the entire request
Timeout time.Duration
// MaxRedirects allowed (0 means follow Go's default of 10)
MaxRedirects int
// UseCookieJar enables cookie jar for the client
UseCookieJar bool
// Connection settings
IdleConnTimeout time.Duration
MaxIdleConns int
ReadBufferSize int
DialTimeout time.Duration
KeepAlive time.Duration
TLSHandshakeTimeout time.Duration
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
IdleConnTimeout time.Duration
// Connection pool settings
MaxIdleConns int
MaxIdleConnsPerHost int
MaxConnsPerHost int
// Buffer settings
WriteBufferSize int
ReadBufferSize int
// Feature flags
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
MaxRedirects int
MaxIdleConnsPerHost int
Timeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
}
// DefaultHTTPClientConfig returns the default configuration for general use
+1 -1
View File
@@ -110,9 +110,9 @@ func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
tests := []struct {
name string
errorMsg string
config HTTPClientConfig
wantError bool
errorMsg string
}{
{
name: "valid config",
+6 -6
View File
@@ -12,19 +12,19 @@ import (
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
type SharedTransportPool struct {
mu sync.RWMutex
transports map[string]*sharedTransport
maxConns int
ctx context.Context
transports map[string]*sharedTransport
cancel context.CancelFunc
clientCount int32 // SECURITY FIX: Track total HTTP clients
maxClients int32 // SECURITY FIX: Limit total clients to 5
maxConns int
mu sync.RWMutex
clientCount int32
maxClients int32
}
type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
lastUsed time.Time
}
var (
+14 -27
View File
@@ -10,6 +10,14 @@ import (
"unicode/utf8"
)
// Pre-compiled regex patterns for validation (const patterns should use MustCompile)
var (
emailRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
urlRegexPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
tokenRegexPattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
usernameRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
)
// InputValidator provides comprehensive input validation and sanitization
// to protect against common security vulnerabilities including SQL injection,
// XSS, path traversal, and other injection attacks. It validates and sanitizes
@@ -73,7 +81,7 @@ func DefaultInputValidationConfig() InputValidationConfig {
}
// NewInputValidator creates a new input validator with the specified configuration.
// It compiles all necessary regex patterns and initializes security pattern lists.
// It uses pre-compiled regex patterns and initializes security pattern lists.
//
// Parameters:
// - config: Validation configuration with size limits and mode settings.
@@ -81,29 +89,8 @@ func DefaultInputValidationConfig() InputValidationConfig {
//
// Returns:
// - A configured InputValidator instance.
// - An error if regex compilation fails.
// - An error (always nil, kept for API compatibility).
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
// Compile regex patterns
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if err != nil {
return nil, fmt.Errorf("failed to compile email regex: %w", err)
}
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
if err != nil {
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
}
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile token regex: %w", err)
}
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile username regex: %w", err)
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
@@ -112,10 +99,10 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
emailRegex: emailRegexPattern,
urlRegex: urlRegexPattern,
tokenRegex: tokenRegexPattern,
usernameRegex: usernameRegexPattern,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
+9 -9
View File
@@ -14,7 +14,7 @@ func TestInputValidator(t *testing.T) {
}
t.Run("Valid token validation", func(t *testing.T) {
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc" // trufflehog:ignore
result := validator.ValidateToken(validToken)
if !result.IsValid {
@@ -428,12 +428,12 @@ func TestInputValidatorValidateToken(t *testing.T) {
tests := []struct {
name string
token string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidJWTToken",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature", // trufflehog:ignore
expectValid: true,
description: "Valid JWT token should pass validation",
},
@@ -475,7 +475,7 @@ func TestInputValidatorValidateToken(t *testing.T) {
},
{
name: "MaliciousJWTWithExtraData",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra", // trufflehog:ignore
expectValid: false,
description: "JWT with extra malicious data should fail validation",
},
@@ -500,8 +500,8 @@ func TestInputValidatorValidateEmail(t *testing.T) {
tests := []struct {
name string
email string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidEmail",
@@ -578,8 +578,8 @@ func TestInputValidatorValidateURL(t *testing.T) {
tests := []struct {
name string
url string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidHTTPSURL",
@@ -669,8 +669,8 @@ func TestInputValidatorValidateClaim(t *testing.T) {
name string
claimName string
claimValue string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidStringClaim",
@@ -750,8 +750,8 @@ func TestInputValidatorValidateHeader(t *testing.T) {
name string
headerName string
headerValue string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidHeader",
@@ -830,8 +830,8 @@ func TestInputValidatorValidateUsername(t *testing.T) {
tests := []struct {
name string
username string
expectValid bool
description string
expectValid bool
}{
{
name: "ValidUsername",
+4 -4
View File
@@ -726,20 +726,20 @@ type MockConfig struct {
}
type MockSession struct {
id string
userID string
created time.Time
lastUsed time.Time
data map[string]interface{}
id string
userID string
}
type TestResult struct {
UserID int
StartTime time.Time
EndTime time.Time
Error error
UserID int
Duration time.Duration
Success bool
Error error
}
// ============================================================================
+14 -25
View File
@@ -18,33 +18,22 @@ const (
// Config provides common configuration for cache backends
type Config struct {
// Type specifies the backend type
Type BackendType
// Memory backend settings
MaxSize int
MaxMemoryBytes int64
CleanupInterval time.Duration
// Redis backend settings
RedisAddr string
RedisPassword string
RedisDB int
RedisPrefix string
PoolSize int
// Hybrid backend settings
L1Config *Config // Memory cache (L1)
L2Config *Config // Redis cache (L2)
AsyncWrites bool // Write to L2 asynchronously
// Resilience settings
L2Config *Config
L1Config *Config
RedisPrefix string
Type BackendType
RedisAddr string
RedisPassword string
PoolSize int
RedisDB int
CleanupInterval time.Duration
MaxMemoryBytes int64
MaxSize int
HealthCheckInterval time.Duration
AsyncWrites bool
EnableCircuitBreaker bool
EnableHealthCheck bool
HealthCheckInterval time.Duration
// Metrics
EnableMetrics bool
EnableMetrics bool
}
// DefaultConfig returns a default configuration for in-memory caching
+18 -28
View File
@@ -13,40 +13,30 @@ import (
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
// It provides automatic failover, async writes for non-critical data, and optimized read paths
type HybridBackend struct {
primary CacheBackend // L1: Memory cache for fast access
secondary CacheBackend // L2: Redis cache for distributed access
// Configuration
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
lastL2Error atomic.Value
secondary CacheBackend
primary CacheBackend
logger Logger
ctx context.Context
syncWriteCacheTypes map[string]bool
asyncWriteBuffer chan *asyncWriteItem
// Metrics
l1Hits atomic.Int64
l2Hits atomic.Int64
misses atomic.Int64
l1Writes atomic.Int64
l2Writes atomic.Int64
errors atomic.Int64
// Fallback tracking
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
lastL2Error atomic.Value // Stores last L2 error timestamp
// Lifecycle
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
// Logging
logger Logger
cancel context.CancelFunc
wg sync.WaitGroup
l1Hits atomic.Int64
errors atomic.Int64
l2Writes atomic.Int64
l1Writes atomic.Int64
misses atomic.Int64
l2Hits atomic.Int64
fallbackMode atomic.Bool
}
// asyncWriteItem represents an async write operation
type asyncWriteItem struct {
ctx context.Context
key string
value []byte
ttl time.Duration
ctx context.Context
}
// Logger interface for structured logging
@@ -82,9 +72,9 @@ func (l *defaultLogger) Errorf(format string, args ...interface{}) {
type HybridConfig struct {
Primary CacheBackend
Secondary CacheBackend
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
AsyncBufferSize int
Logger Logger
SyncWriteCacheTypes map[string]bool
AsyncBufferSize int
}
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
+6 -6
View File
@@ -17,23 +17,23 @@ import (
// mockBackend is a simple mock implementation of CacheBackend for testing
type mockBackend struct {
pingError error
data map[string]mockEntry
stats map[string]interface{}
mu sync.RWMutex
getCalls atomic.Int32
setCalls atomic.Int32
deleteCalls atomic.Int32
failSet bool
failGet bool
failDelete bool
failClear bool
failPing bool
pingError error
stats map[string]interface{}
getCalls atomic.Int32
setCalls atomic.Int32
deleteCalls atomic.Int32
}
type mockEntry struct {
value []byte
expiresAt time.Time
value []byte
}
// mockBatchBackend extends mockBackend with batch operations
+14 -45
View File
@@ -41,53 +41,22 @@ type CacheBackend interface {
// BackendStats represents statistics for a cache backend
type BackendStats struct {
// Type is the backend type
Type BackendType
// Hits is the number of cache hits
Hits int64
// Misses is the number of cache misses
Misses int64
// Sets is the number of set operations
Sets int64
// Deletes is the number of delete operations
Deletes int64
// Errors is the number of errors
Errors int64
// Evictions is the number of evicted items
Evictions int64
// CurrentSize is the current number of items in cache
CurrentSize int64
// MaxSize is the maximum number of items (0 means unlimited)
MaxSize int64
// MemoryUsage is the approximate memory usage in bytes
MemoryUsage int64
// AverageGetLatency is the average latency for get operations
StartTime time.Time
LastErrorTime time.Time
Type BackendType
LastError string
Deletes int64
Errors int64
Evictions int64
CurrentSize int64
MaxSize int64
MemoryUsage int64
AverageGetLatency time.Duration
// AverageSetLatency is the average latency for set operations
AverageSetLatency time.Duration
// LastError is the last error encountered
LastError string
// LastErrorTime is when the last error occurred
LastErrorTime time.Time
// Uptime is how long the backend has been running
Uptime time.Duration
// StartTime is when the backend was started
StartTime time.Time
Sets int64
Misses int64
Uptime time.Duration
Hits int64
}
// BackendCapabilities describes the capabilities of a cache backend
+219 -200
View File
@@ -2,23 +2,30 @@
package backends
import (
"container/list"
"context"
"sync"
"sync/atomic"
"time"
)
// Default configuration values
const (
defaultShardCount = 256
defaultMaxSize = int64(10000)
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
defaultCleanupInterval = 5 * time.Minute
)
// memoryCacheItem represents an item in the memory cache
type memoryCacheItem struct {
key string
value interface{}
expiresAt time.Time
createdAt time.Time
accessedAt time.Time
value interface{}
element interface{} // *list.Element, using interface{} to avoid import cycle
key string
accessCount int64
size int64
element *list.Element // for LRU tracking
}
// isExpired checks if the item is expired
@@ -29,17 +36,23 @@ func (item *memoryCacheItem) isExpired() bool {
return time.Now().After(item.expiresAt)
}
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
// The sharded design reduces lock contention by partitioning keys across multiple shards,
// each with its own lock.
type MemoryCacheBackend struct {
mu sync.RWMutex
items map[string]*memoryCacheItem
lruList *list.List
maxSize int64
maxMemory int64
currentSize int64
currentMemory int64
shards []*cacheShard
startTime time.Time
lastErrorTime time.Time
cleanupDone chan struct{}
cleanupTicker *time.Ticker
lastError string
shardCount uint32
shardMask uint32
maxSize int64
maxMemory int64
cleanupInterval time.Duration
// Statistics
// Global stats (aggregated from shards)
hits atomic.Int64
misses atomic.Int64
sets atomic.Int64
@@ -53,40 +66,59 @@ type MemoryCacheBackend struct {
getCount atomic.Int64
setCount atomic.Int64
// Status
startTime time.Time
lastError string
lastErrorTime time.Time
cleanupTicker *time.Ticker
cleanupDone chan bool
closed atomic.Bool
// Configuration
cleanupInterval time.Duration
evictionPolicy string // "lru", "lfu", "fifo"
// State
closed atomic.Bool
mu sync.RWMutex // For global operations like stats and error tracking
}
// NewMemoryCacheBackend creates a new memory cache backend
// NewMemoryCacheBackend creates a new sharded memory cache backend
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
if maxSize <= 0 {
maxSize = 10000 // Default to 10k items
maxSize = defaultMaxSize
}
if maxMemory <= 0 {
maxMemory = 100 * 1024 * 1024 // Default to 100MB
maxMemory = defaultMaxMemory
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute
cleanupInterval = defaultCleanupInterval
}
shardCount := uint32(defaultShardCount)
// For very small caches, reduce shard count to maintain sensible per-shard limits
// Ensure each shard can hold at least 2 items for proper LRU behavior
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
shardCount /= 2
}
if shardCount < 1 {
shardCount = 1
}
// Per-shard limits are soft hints; global limits are enforced
// Give shards 2x the average to allow for uneven distribution
shardMaxSize := (maxSize * 2) / int64(shardCount)
if shardMaxSize < 4 {
shardMaxSize = 4
}
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
if shardMaxMemory < 4096 {
shardMaxMemory = 4096 // Minimum 4KB per shard
}
m := &MemoryCacheBackend{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
shards: make([]*cacheShard, shardCount),
shardCount: shardCount,
shardMask: shardCount - 1, // For fast modulo with power-of-2
maxSize: maxSize,
maxMemory: maxMemory,
startTime: time.Now(),
cleanupInterval: cleanupInterval,
evictionPolicy: "lru",
cleanupDone: make(chan bool),
cleanupDone: make(chan struct{}),
}
// Initialize shards
for i := uint32(0); i < shardCount; i++ {
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
}
// Start cleanup goroutine
@@ -96,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.
return m
}
// getShard returns the shard for a given key
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
hash := fnv32(key)
return m.shards[hash&m.shardMask]
}
// cleanupLoop runs periodic cleanup of expired items
func (m *MemoryCacheBackend) cleanupLoop() {
for {
@@ -108,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() {
}
}
// cleanupExpired removes all expired items from the cache
// cleanupExpired removes all expired items from all shards
func (m *MemoryCacheBackend) cleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
var keysToDelete []string
for key, item := range m.items {
if item.isExpired() {
keysToDelete = append(keysToDelete, key)
}
if m.closed.Load() {
return
}
for _, key := range keysToDelete {
m.deleteItemLocked(key)
totalRemoved := 0
for _, shard := range m.shards {
totalRemoved += shard.cleanup()
}
if totalRemoved > 0 {
m.evictions.Add(int64(totalRemoved))
}
}
@@ -138,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{},
m.getCount.Add(1)
}()
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
shard := m.getShard(key)
value, exists, expired := shard.get(key)
if expired {
// Clean up expired item
shard.delete(key)
m.misses.Add(1)
return nil, ErrCacheMiss
}
if !exists {
m.misses.Add(1)
return nil, ErrCacheMiss
}
if item.isExpired() {
m.mu.Lock()
m.deleteItemLocked(key)
m.mu.Unlock()
m.misses.Add(1)
return nil, ErrCacheMiss
}
// Update access time and count
m.mu.Lock()
item.accessedAt = time.Now()
item.accessCount++
// Move to front of LRU list
if m.evictionPolicy == "lru" && item.element != nil {
m.lruList.MoveToFront(item.element)
}
m.mu.Unlock()
m.hits.Add(1)
return item.value, nil
return value, nil
}
// Set stores a value in the cache with optional TTL
@@ -182,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac
m.setCount.Add(1)
}()
// Calculate item size (simplified estimation)
// Calculate item size
itemSize := int64(len(key)) + estimateValueSize(value)
m.mu.Lock()
defer m.mu.Unlock()
// Enforce global limits before adding new item
m.enforceGlobalLimits(itemSize)
// Check if we need to evict items
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
m.evictLocked()
}
// Check if key exists
if oldItem, exists := m.items[key]; exists {
m.currentMemory -= oldItem.size
if oldItem.element != nil {
m.lruList.Remove(oldItem.element)
}
} else {
m.currentSize++
}
now := time.Now()
var expiresAt time.Time
if ttl > 0 {
expiresAt = now.Add(ttl)
expiresAt = time.Now().Add(ttl)
}
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: itemSize,
}
shard := m.getShard(key)
shard.set(key, value, expiresAt, itemSize)
// Add to LRU list
if m.evictionPolicy == "lru" {
item.element = m.lruList.PushFront(item)
}
m.items[key] = item
m.currentMemory += itemSize
m.sets.Add(1)
return nil
}
// enforceGlobalLimits ensures global size and memory limits are respected
// by evicting from shards when necessary
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
// Check and enforce size limit
for {
totalSize, totalMemory := m.getGlobalStats()
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
if !needsSizeEviction && !needsMemoryEviction {
break
}
// Find the shard with the most items and evict from it
evicted := m.evictFromLargestShard()
if !evicted {
break // No more items to evict
}
m.evictions.Add(1)
}
}
// getGlobalStats returns the total size and memory usage across all shards
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
for _, shard := range m.shards {
size, memory := shard.stats()
totalSize += size
totalMemory += memory
}
return
}
// evictFromLargestShard evicts the globally oldest item across all shards
// This provides true LRU behavior even with sharding
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
var oldestShard *cacheShard
var oldestTime time.Time
for _, shard := range m.shards {
accessTime := shard.getOldestAccessTime()
// Skip empty shards
if accessTime.IsZero() {
continue
}
// Find the shard with the oldest (earliest) access time
if oldestShard == nil || accessTime.Before(oldestTime) {
oldestTime = accessTime
oldestShard = shard
}
}
if oldestShard == nil {
return false
}
return oldestShard.evictOne()
}
// Delete removes a key from the cache
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.items[key]; !exists {
return nil
shard := m.getShard(key)
if shard.delete(key) {
m.deletes.Add(1)
}
m.deleteItemLocked(key)
m.deletes.Add(1)
return nil
}
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
if item, exists := m.items[key]; exists {
m.currentMemory -= item.size
m.currentSize--
if item.element != nil {
m.lruList.Remove(item.element)
}
delete(m.items, key)
}
}
// evictLocked evicts items based on the eviction policy (must be called with lock held)
func (m *MemoryCacheBackend) evictLocked() {
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
// Evict least recently used item
element := m.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
m.deleteItemLocked(item.key)
m.evictions.Add(1)
}
}
}
// Exists checks if a key exists in the cache
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
if m.closed.Load() {
return false, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
return false, nil
}
return !item.isExpired(), nil
shard := m.getShard(key)
return shard.exists(key), nil
}
// Clear removes all items from the cache
@@ -297,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryCacheItem)
m.lruList = list.New()
m.currentSize = 0
m.currentMemory = 0
for _, shard := range m.shards {
shard.clear()
}
return nil
}
@@ -314,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string
return nil, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var keys []string
for key, item := range m.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
var allKeys []string
for _, shard := range m.shards {
keys := shard.keys(pattern)
allKeys = append(allKeys, keys...)
}
return keys, nil
return allKeys, nil
}
// Size returns the number of items in the cache
// Size returns the total number of items in the cache
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var total int64
for _, shard := range m.shards {
size, _ := shard.stats()
total += size
}
return m.currentSize, nil
return total, nil
}
// TTL returns the remaining time-to-live for a key
@@ -345,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration
return 0, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists || item.isExpired() {
shard := m.getShard(key)
ttl, exists := shard.ttl(key)
if !exists {
return 0, ErrCacheMiss
}
if item.expiresAt.IsZero() {
return 0, nil // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, nil
}
return remaining, nil
return ttl, nil
}
// Expire updates the TTL for an existing key
@@ -371,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists || item.isExpired() {
shard := m.getShard(key)
if !shard.expire(key, ttl) {
return ErrCacheMiss
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return nil
}
@@ -394,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
return nil, ErrBackendUnavailable
}
// Aggregate stats from all shards
var totalSize, totalMemory int64
for _, shard := range m.shards {
size, memory := shard.stats()
totalSize += size
totalMemory += memory
}
m.mu.RLock()
lastError := m.lastError
lastErrorTime := m.lastErrorTime
@@ -417,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
Deletes: m.deletes.Load(),
Errors: m.errors.Load(),
Evictions: m.evictions.Load(),
CurrentSize: m.currentSize,
CurrentSize: totalSize,
MaxSize: m.maxSize,
MemoryUsage: m.currentMemory,
MemoryUsage: totalMemory,
AverageGetLatency: avgGetLatency,
AverageSetLatency: avgSetLatency,
LastError: lastError,
@@ -446,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error {
m.cleanupTicker.Stop()
close(m.cleanupDone)
m.mu.Lock()
m.items = nil
m.lruList = nil
m.mu.Unlock()
// Clear all shards
for _, shard := range m.shards {
shard.clear()
}
return nil
}
@@ -482,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
}
}
// GetShardCount returns the number of shards (for testing/monitoring)
func (m *MemoryCacheBackend) GetShardCount() uint32 {
return m.shardCount
}
// GetShardStats returns per-shard statistics (for monitoring)
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
stats := make([]map[string]int64, m.shardCount)
for i, shard := range m.shards {
size, memory := shard.stats()
stats[i] = map[string]int64{
"size": size,
"memory": memory,
}
}
return stats
}
// Helper functions
// estimateValueSize estimates the size of a value in bytes
func estimateValueSize(value interface{}) int64 {
// This is a simplified estimation
// In production, you might want to use a more accurate method
switch v := value.(type) {
case string:
return int64(len(v))
@@ -510,7 +526,10 @@ func matchPattern(pattern, key string) bool {
if pattern == "*" {
return true
}
// Simplified pattern matching - in production, use a proper glob library
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
// Simplified pattern matching
if len(pattern) > 0 && pattern[0] == '*' {
suffix := pattern[1:]
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
}
return key == pattern
}
+294
View File
@@ -0,0 +1,294 @@
package backends
import (
"container/list"
"sync"
"time"
)
// cacheShard represents a single shard of the sharded cache
// Each shard has its own lock for reduced contention
type cacheShard struct {
items map[string]*memoryCacheItem
lruList *list.List
mu sync.RWMutex
maxSize int64
maxMemory int64
size int64
memoryUsed int64
}
// newCacheShard creates a new cache shard
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
return &cacheShard{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
maxSize: maxSize,
maxMemory: maxMemory,
}
}
// get retrieves a value from this shard
// Returns: value, exists, expired
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists {
return nil, false, false
}
if item.isExpired() {
return nil, true, true // exists but expired
}
// Update access time and LRU position under write lock
s.mu.Lock()
// Re-check item exists (could have been deleted)
item, exists = s.items[key]
if exists && !item.isExpired() {
item.accessedAt = time.Now()
item.accessCount++
if elem, ok := item.element.(*list.Element); ok && elem != nil {
s.lruList.MoveToFront(elem)
}
}
s.mu.Unlock()
if !exists || item.isExpired() {
return nil, false, false
}
return item.value, true, false
}
// set stores a value in this shard
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
s.mu.Lock()
defer s.mu.Unlock()
// Check if we need to evict items
if s.maxSize > 0 && s.size >= s.maxSize {
s.evictLRULocked()
}
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
s.evictLRULocked()
}
// Remove old item if exists
if oldItem, exists := s.items[key]; exists {
s.memoryUsed -= oldItem.size
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
s.lruList.Remove(elem)
}
s.size--
}
now := time.Now()
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: size,
}
item.element = s.lruList.PushFront(item)
s.items[key] = item
s.size++
s.memoryUsed += size
}
// delete removes a key from this shard
// Returns true if the key was deleted
func (s *cacheShard) delete(key string) bool {
s.mu.Lock()
defer s.mu.Unlock()
item, exists := s.items[key]
if !exists {
return false
}
s.deleteItemLocked(item)
return true
}
// exists checks if a key exists (and is not expired)
func (s *cacheShard) exists(key string) bool {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists {
return false
}
return !item.isExpired()
}
// ttl returns the remaining TTL for a key
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
s.mu.RLock()
item, exists := s.items[key]
s.mu.RUnlock()
if !exists || item.isExpired() {
return 0, false
}
if item.expiresAt.IsZero() {
return 0, true // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, false
}
return remaining, true
}
// expire updates the TTL for an existing key
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
s.mu.Lock()
defer s.mu.Unlock()
item, exists := s.items[key]
if !exists || item.isExpired() {
return false
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return true
}
// keys returns all non-expired keys matching the pattern
func (s *cacheShard) keys(pattern string) []string {
s.mu.RLock()
defer s.mu.RUnlock()
var keys []string
for key, item := range s.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
}
return keys
}
// clear removes all items from this shard
func (s *cacheShard) clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.items = make(map[string]*memoryCacheItem)
s.lruList.Init()
s.size = 0
s.memoryUsed = 0
}
// cleanup removes expired items
// Returns the number of items removed
func (s *cacheShard) cleanup() int {
s.mu.Lock()
defer s.mu.Unlock()
var toRemove []*memoryCacheItem
for _, item := range s.items {
if item.isExpired() {
toRemove = append(toRemove, item)
}
}
for _, item := range toRemove {
s.deleteItemLocked(item)
}
return len(toRemove)
}
// stats returns statistics for this shard
func (s *cacheShard) stats() (size, memory int64) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.size, s.memoryUsed
}
// deleteItemLocked removes an item (must be called with lock held)
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
if elem, ok := item.element.(*list.Element); ok && elem != nil {
s.lruList.Remove(elem)
}
delete(s.items, item.key)
s.size--
s.memoryUsed -= item.size
}
// evictLRULocked evicts the least recently used item (must be called with lock held)
func (s *cacheShard) evictLRULocked() bool {
if s.lruList.Len() == 0 {
return false
}
element := s.lruList.Back()
if element != nil {
item, ok := element.Value.(*memoryCacheItem)
if ok {
s.deleteItemLocked(item)
return true
}
}
return false
}
// evictOne evicts one item from this shard (for global limit enforcement)
func (s *cacheShard) evictOne() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.evictLRULocked()
}
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
// Returns zero time if shard is empty
func (s *cacheShard) getOldestAccessTime() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
if s.lruList.Len() == 0 {
return time.Time{}
}
element := s.lruList.Back()
if element != nil {
item, ok := element.Value.(*memoryCacheItem)
if ok {
return item.accessedAt
}
}
return time.Time{}
}
// fnv32 computes FNV-1a hash of a string
// This is a fast, well-distributed hash function
func fnv32(key string) uint32 {
const (
offset32 = uint32(2166136261)
prime32 = uint32(16777619)
)
hash := offset32
for i := 0; i < len(key); i++ {
hash ^= uint32(key[i])
hash *= prime32
}
return hash
}
+283
View File
@@ -0,0 +1,283 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
func TestShardedCache_ShardDistribution(t *testing.T) {
t.Parallel()
// Create a cache with large enough size to have multiple shards
config := DefaultConfig()
config.MaxSize = 10000
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add many items to see distribution
numItems := 1000
for i := 0; i < numItems; i++ {
key := fmt.Sprintf("dist-key-%d", i)
value := []byte(fmt.Sprintf("dist-value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Check that items are distributed across multiple shards
shardStats := backend.MemoryCacheBackend.GetShardStats()
nonEmptyShards := 0
for _, stat := range shardStats {
if stat["size"] > 0 {
nonEmptyShards++
}
}
// With good hash distribution, we should have items in multiple shards
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
}
// TestShardedCache_ShardCount tests that shard count adapts to cache size
func TestShardedCache_ShardCount(t *testing.T) {
t.Parallel()
tests := []struct {
maxSize int
expectLowShards bool
}{
{5, true}, // Very small cache should have fewer shards
{100, true}, // Small cache should have fewer shards
{10000, false}, // Large cache should have default shards
}
for _, tt := range tests {
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
config := DefaultConfig()
config.MaxSize = tt.maxSize
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
shardCount := backend.MemoryCacheBackend.GetShardCount()
if tt.expectLowShards {
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
} else {
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
}
})
}
}
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "concurrent-same-key"
initialValue := []byte("initial-value")
err = backend.Set(ctx, key, initialValue, time.Minute)
require.NoError(t, err)
var wg sync.WaitGroup
goroutines := 50
iterations := 100
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Mix of reads and writes
if j%3 == 0 {
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
err := backend.Set(ctx, key, newValue, time.Minute)
assert.NoError(t, err)
} else {
_, _, _, err := backend.Get(ctx, key)
assert.NoError(t, err)
}
}
}(i)
}
wg.Wait()
// Key should still exist
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
t.Parallel()
// Create a small cache to force eviction
config := DefaultConfig()
config.MaxSize = 10
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
// Small delay to ensure different access times
time.Sleep(time.Millisecond)
}
// Access some items to make them recently used
for i := 5; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
_, _, _, err := backend.Get(ctx, key)
require.NoError(t, err)
}
// Add more items to trigger eviction
for i := 10; i < 15; i++ {
key := fmt.Sprintf("global-lru-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Recently accessed items (5-9) should still exist
for i := 5; i < 10; i++ {
key := fmt.Sprintf("global-lru-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item %d should exist", i)
}
// Check eviction stats
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
func TestShardedCache_StatsAggregation(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 10000
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items to multiple shards
numItems := 100
for i := 0; i < numItems; i++ {
key := fmt.Sprintf("stats-key-%d", i)
value := []byte(fmt.Sprintf("stats-value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
// Read some items
for i := 0; i < numItems/2; i++ {
key := fmt.Sprintf("stats-key-%d", i)
backend.Get(ctx, key)
}
// Read non-existent items
for i := 0; i < 10; i++ {
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
}
stats := backend.GetStats()
// Verify stats
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
// Verify hit rate
hitRate := stats["hit_rate"].(float64)
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
}
// BenchmarkShardedCache_Parallel benchmarks parallel access
func BenchmarkShardedCache_Parallel(b *testing.B) {
config := DefaultConfig()
config.MaxSize = 100000
config.MaxMemoryBytes = 100 * 1024 * 1024
backend, _ := NewMemoryBackend(config)
defer backend.Close()
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 10000; i++ {
key := fmt.Sprintf("bench-key-%d", i)
value := []byte(fmt.Sprintf("bench-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("bench-key-%d", i%10000)
backend.Get(ctx, key)
i++
}
})
}
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
func BenchmarkShardedCache_MixedOps(b *testing.B) {
config := DefaultConfig()
config.MaxSize = 100000
config.MaxMemoryBytes = 100 * 1024 * 1024
backend, _ := NewMemoryBackend(config)
defer backend.Close()
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("mixed-key-%d", i%1000)
if i%3 == 0 {
value := []byte(fmt.Sprintf("mixed-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
} else {
backend.Get(ctx, key)
}
i++
}
})
}
+20 -30
View File
@@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
return nil, 0, false, err
}
// Get the item directly to check TTL
m.MemoryCacheBackend.mu.RLock()
item, exists := m.MemoryCacheBackend.items[key]
m.MemoryCacheBackend.mu.RUnlock()
if !exists {
return nil, 0, false, nil
}
var ttl time.Duration
if !item.expiresAt.IsZero() {
ttl = time.Until(item.expiresAt)
if ttl < 0 {
ttl = 0
}
// Get TTL using the TTL method
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
if ttlErr != nil {
// If we can't get TTL, still return the value with 0 TTL
ttl = 0
}
// Convert interface{} to []byte
@@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
if bytes, ok := val.([]byte); ok {
valueBytes = bytes
} else {
// If it's not already []byte, we might need to handle other types
// For now, we'll just return an error
// If it's not already []byte, return an error
return nil, 0, false, ErrInvalidValue
}
}
@@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} {
}
return map[string]interface{}{
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
"shard_count": m.MemoryCacheBackend.GetShardCount(),
}
}
+109 -13
View File
@@ -345,7 +345,7 @@ func (r *RedisBackend) prefixKey(key string) string {
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
// It checks context cancellation at multiple points to ensure fast abort when the
// caller's context is cancelled (e.g., due to request timeout).
// caller's context is canceled (e.g., due to request timeout).
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
@@ -377,7 +377,7 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
err = operation(conn)
r.pool.Put(conn)
// Check context after operation - if cancelled, don't bother retrying
// Check context after operation - if canceled, don't bother retrying
if ctx.Err() != nil {
return ctx.Err()
}
@@ -431,39 +431,135 @@ func isRetryableError(err error) bool {
return false
}
// SetMany stores multiple values in Redis (batch operation)
// SetMany stores multiple values in Redis using pipelining for efficiency
// This reduces N round-trips to a single round-trip
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
// For simplicity, execute sequentially (can be optimized with pipelining later)
for key, value := range items {
if err := r.Set(ctx, key, value, ttl); err != nil {
return err
if len(items) == 0 {
return nil
}
// For single items, use regular Set
if len(items) == 1 {
for key, value := range items {
return r.Set(ctx, key, value, ttl)
}
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
pipeline := conn.NewPipeline()
// Queue all SET commands
ttlSeconds := int(ttl.Seconds())
ttlMillis := ttl.Milliseconds()
for key, value := range items {
prefixedKey := r.prefixKey(key)
if ttl > 0 {
if ttlMillis < 1000 {
// Use PSETEX for sub-second TTLs
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
} else {
// Use SETEX for larger TTLs
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
}
} else {
pipeline.Queue("SET", prefixedKey, string(value))
}
}
// Execute pipeline
responses, err := pipeline.Execute()
if err != nil {
return fmt.Errorf("pipeline SetMany failed: %w", err)
}
// Check responses for errors (each should be "OK")
for i, resp := range responses {
if resp == nil {
continue
}
if str, ok := resp.(string); ok && str == "OK" {
continue
}
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
}
return nil
}
// GetMany retrieves multiple values from Redis
// GetMany retrieves multiple values from Redis using pipelining for efficiency
// This reduces N round-trips to a single round-trip
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if r.closed.Load() {
return nil, ErrBackendClosed
}
result := make(map[string][]byte)
if len(keys) == 0 {
return make(map[string][]byte), nil
}
// For simplicity, execute sequentially
for _, key := range keys {
value, _, exists, err := r.Get(ctx, key)
// For single key, use regular Get
if len(keys) == 1 {
result := make(map[string][]byte)
value, _, exists, err := r.Get(ctx, keys[0])
if err != nil {
return nil, err
}
if exists {
result[key] = value
result[keys[0]] = value
}
return result, nil
}
conn, err := r.pool.Get(ctx)
if err != nil {
return nil, err
}
defer r.pool.Put(conn)
pipeline := conn.NewPipeline()
// Queue all GET commands
prefixedKeys := make([]string, len(keys))
for i, key := range keys {
prefixedKeys[i] = r.prefixKey(key)
pipeline.Queue("GET", prefixedKeys[i])
}
// Execute pipeline
responses, err := pipeline.Execute()
if err != nil {
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
}
// Process responses
result := make(map[string][]byte)
for i, resp := range responses {
if resp == nil {
// Key doesn't exist
r.misses.Add(1)
continue
}
value, err := RESPString(resp)
if err != nil {
// Invalid response, skip this key
r.misses.Add(1)
continue
}
r.hits.Add(1)
result[keys[i]] = []byte(value)
}
return result, nil
+10 -16
View File
@@ -9,30 +9,24 @@ import (
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
type HealthMonitor struct {
pool *ConnectionPool
config *HealthMonitorConfig
// State
healthy atomic.Bool
running atomic.Bool
lastCheckTime atomic.Int64 // Unix timestamp
// Metrics
pool *ConnectionPool
config *HealthMonitorConfig
stopChan chan struct{}
wg sync.WaitGroup
lastCheckTime atomic.Int64
consecutiveFailures atomic.Int64
totalChecks atomic.Int64
totalFailures atomic.Int64
// Lifecycle
stopChan chan struct{}
wg sync.WaitGroup
healthy atomic.Bool
running atomic.Bool
}
// HealthMonitorConfig configures the health monitor
type HealthMonitorConfig struct {
CheckInterval time.Duration // How often to check health
Timeout time.Duration // Timeout for health check
UnhealthyThreshold int // Consecutive failures before marking unhealthy
OnHealthChange func(healthy bool)
CheckInterval time.Duration
Timeout time.Duration
UnhealthyThreshold int
}
// DefaultHealthMonitorConfig returns default health monitor configuration
+461
View File
@@ -0,0 +1,461 @@
package backends
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// setupTestRedis creates a miniredis instance for testing
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err)
t.Cleanup(func() {
mr.Close()
})
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "test:",
PoolSize: 5,
})
require.NoError(t, err)
t.Cleanup(func() {
backend.Close()
})
return mr, backend
}
// TestPipeline_Basic tests basic pipeline functionality
func TestPipeline_Basic(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
config := &PoolConfig{
Address: mr.Addr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
defer pool.Put(conn)
t.Run("SingleCommand", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("SET", "single-key", "single-value")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 1)
assert.Equal(t, "OK", responses[0])
})
t.Run("MultipleCommands", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("SET", "key1", "value1")
pipeline.Queue("SET", "key2", "value2")
pipeline.Queue("SET", "key3", "value3")
pipeline.Queue("GET", "key1")
pipeline.Queue("GET", "key2")
pipeline.Queue("GET", "key3")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 6)
// First 3 are SET responses
assert.Equal(t, "OK", responses[0])
assert.Equal(t, "OK", responses[1])
assert.Equal(t, "OK", responses[2])
// Last 3 are GET responses
assert.Equal(t, "value1", responses[3])
assert.Equal(t, "value2", responses[4])
assert.Equal(t, "value3", responses[5])
})
t.Run("EmptyPipeline", func(t *testing.T) {
pipeline := conn.NewPipeline()
responses, err := pipeline.Execute()
require.NoError(t, err)
assert.Nil(t, responses)
})
t.Run("NilResponses", func(t *testing.T) {
pipeline := conn.NewPipeline()
pipeline.Queue("GET", "nonexistent-key")
responses, err := pipeline.Execute()
require.NoError(t, err)
require.Len(t, responses, 1)
assert.Nil(t, responses[0])
})
}
// TestPipeline_SetMany tests pipelined SetMany
func TestPipeline_SetMany(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
t.Run("SetManyItems", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
value, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key %s should exist", key)
assert.Equal(t, expectedValue, value)
}
})
t.Run("SetManyEmpty", func(t *testing.T) {
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
require.NoError(t, err)
})
t.Run("SetManySingleItem", func(t *testing.T) {
items := map[string][]byte{
"single-setmany": []byte("single-value"),
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
value, _, exists, err := backend.Get(ctx, "single-setmany")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("single-value"), value)
})
t.Run("SetManyNoTTL", func(t *testing.T) {
items := map[string][]byte{
"nottl-key1": []byte("value1"),
"nottl-key2": []byte("value2"),
}
err := backend.SetMany(ctx, items, 0)
require.NoError(t, err)
// Keys should exist
for key := range items {
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
})
}
// TestPipeline_GetMany tests pipelined GetMany
func TestPipeline_GetMany(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 10; i++ {
key := fmt.Sprintf("getmany-key-%d", i)
value := []byte(fmt.Sprintf("value-%d", i))
err := backend.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
}
t.Run("GetManyExisting", func(t *testing.T) {
keys := make([]string, 10)
for i := 0; i < 10; i++ {
keys[i] = fmt.Sprintf("getmany-key-%d", i)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 10)
for i, key := range keys {
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
}
})
t.Run("GetManyMixed", func(t *testing.T) {
keys := []string{
"getmany-key-0", // exists
"nonexistent-key-1", // doesn't exist
"getmany-key-2", // exists
"nonexistent-key-2", // doesn't exist
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
assert.NotContains(t, results, "nonexistent-key-1")
assert.NotContains(t, results, "nonexistent-key-2")
})
t.Run("GetManyEmpty", func(t *testing.T) {
results, err := backend.GetMany(ctx, []string{})
require.NoError(t, err)
assert.NotNil(t, results)
assert.Len(t, results, 0)
})
t.Run("GetManySingleKey", func(t *testing.T) {
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
})
t.Run("GetManyAllNonexistent", func(t *testing.T) {
keys := []string{
"nonexistent-1",
"nonexistent-2",
"nonexistent-3",
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 0)
})
}
// TestPipeline_LargeBatch tests pipelining with large batches
func TestPipeline_LargeBatch(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
t.Run("SetMany100Items", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 100; i++ {
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Verify random samples
for _, i := range []int{0, 25, 50, 75, 99} {
key := fmt.Sprintf("large-batch-%d", i)
value, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
}
})
t.Run("GetMany100Items", func(t *testing.T) {
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("large-batch-%d", i)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 100)
})
}
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
func TestPipeline_Stats(t *testing.T) {
t.Parallel()
_, backend := setupTestRedis(t)
ctx := context.Background()
// Set some items
items := map[string][]byte{
"stats-key-1": []byte("value1"),
"stats-key-2": []byte("value2"),
}
err := backend.SetMany(ctx, items, time.Minute)
require.NoError(t, err)
// Get items (some exist, some don't)
keys := []string{
"stats-key-1",
"stats-key-2",
"stats-key-nonexistent",
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2)
// Check stats
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Equal(t, int64(2), hits, "Should have 2 hits")
assert.Equal(t, int64(1), misses, "Should have 1 miss")
}
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
func BenchmarkPipeline_SetMany(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Prepare items
items := make(map[string][]byte)
for i := 0; i < 100; i++ {
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = backend.SetMany(ctx, items, time.Minute)
}
}
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
func BenchmarkPipeline_GetMany(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Pre-populate cache
for i := 0; i < 100; i++ {
key := fmt.Sprintf("bench-key-%d", i)
value := []byte(fmt.Sprintf("bench-value-%d", i))
backend.Set(ctx, key, value, time.Hour)
}
// Prepare keys
keys := make([]string, 100)
for i := 0; i < 100; i++ {
keys[i] = fmt.Sprintf("bench-key-%d", i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = backend.GetMany(ctx, keys)
}
}
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
func BenchmarkPipeline_VsSequential(b *testing.B) {
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
defer mr.Close()
backend, err := NewRedisBackend(&Config{
RedisAddr: mr.Addr(),
RedisPrefix: "bench:",
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
// Prepare items
items := make(map[string][]byte)
keys := make([]string, 50)
for i := 0; i < 50; i++ {
key := fmt.Sprintf("compare-key-%d", i)
keys[i] = key
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
}
b.Run("Pipelined-Set", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = backend.SetMany(ctx, items, time.Minute)
}
})
b.Run("Sequential-Set", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for key, value := range items {
_ = backend.Set(ctx, key, value, time.Minute)
}
}
})
// Pre-populate for get benchmarks
_ = backend.SetMany(ctx, items, time.Hour)
b.Run("Pipelined-Get", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = backend.GetMany(ctx, keys)
}
})
b.Run("Sequential-Get", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, key := range keys {
_, _, _, _ = backend.Get(ctx, key)
}
}
})
}
+117
View File
@@ -336,3 +336,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
_, err := conn.Do("PING")
return err == nil
}
// Pipeline represents a Redis pipeline for batch operations
// It queues multiple commands and executes them in a single round-trip
type Pipeline struct {
conn *RedisConn
commands []pipelineCommand
mu sync.Mutex
}
// pipelineCommand represents a single command in the pipeline
type pipelineCommand struct {
command string
args []string
}
// NewPipeline creates a new pipeline for the connection
func (c *RedisConn) NewPipeline() *Pipeline {
return &Pipeline{
conn: c,
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
}
}
// Queue adds a command to the pipeline
func (p *Pipeline) Queue(command string, args ...string) {
p.mu.Lock()
defer p.mu.Unlock()
p.commands = append(p.commands, pipelineCommand{
command: command,
args: args,
})
}
// Execute sends all queued commands and returns all responses
// Returns a slice of responses in the same order as commands were queued
func (p *Pipeline) Execute() ([]interface{}, error) {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.commands) == 0 {
return nil, nil
}
if p.conn.closed.Load() {
return nil, ErrBackendClosed
}
p.conn.mu.Lock()
defer p.conn.mu.Unlock()
// Set write timeout for all commands
if p.conn.writeTimeout > 0 {
// Use longer timeout for batch operations
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
if timeout > 30*time.Second {
timeout = 30 * time.Second // Cap at 30 seconds
}
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
}
// Write all commands (pipelining - send all before reading any responses)
writer := NewRESPWriter(p.conn.conn)
for _, cmd := range p.commands {
cmdArgs := append([]string{cmd.command}, cmd.args...)
if err := writer.WriteCommand(cmdArgs...); err != nil {
writer.Release()
p.conn.closed.Store(true)
return nil, fmt.Errorf("pipeline write error: %w", err)
}
}
writer.Release()
// Set read timeout for all responses
if p.conn.readTimeout > 0 {
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
if timeout > 30*time.Second {
timeout = 30 * time.Second
}
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
}
// Read all responses
responses := make([]interface{}, len(p.commands))
reader := NewRESPReader(p.conn.conn)
defer reader.Release()
for i := range p.commands {
resp, err := reader.ReadResponse()
if err != nil {
// For nil responses, store nil instead of erroring
if errors.Is(err, ErrNilResponse) {
responses[i] = nil
continue
}
p.conn.closed.Store(true)
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
}
responses[i] = resp
}
return responses, nil
}
// Clear resets the pipeline for reuse
func (p *Pipeline) Clear() {
p.mu.Lock()
defer p.mu.Unlock()
p.commands = p.commands[:0]
}
// Len returns the number of queued commands
func (p *Pipeline) Len() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.commands)
}
+1 -1
View File
@@ -201,7 +201,7 @@ func TestConnectionPool_ContextCancellation(t *testing.T) {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
// Try to get another with canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
+15 -34
View File
@@ -7,52 +7,34 @@ import (
"io"
"strconv"
"strings"
"sync"
)
// RESP (REdis Serialization Protocol) implementation
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
//
// NOTE: sync.Pool was intentionally removed for Yaegi compatibility.
// Yaegi (Traefik's Go interpreter) has issues with sync.Pool and reflection
// that cause "reflect: call of reflect.Value.Field on zero Value" panics.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/120
var (
ErrInvalidRESP = errors.New("invalid RESP response")
ErrNilResponse = errors.New("nil response")
)
// Object pools for memory optimization - reduces allocations by 50-70%
var (
readerPool = sync.Pool{
New: func() interface{} {
return &RESPReader{
r: bufio.NewReaderSize(nil, 4096),
}
},
}
writerPool = sync.Pool{
New: func() interface{} {
return &RESPWriter{
w: nil,
}
},
}
)
// RESPWriter writes RESP protocol messages
type RESPWriter struct {
w io.Writer
}
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
// NewRESPWriter creates a new RESP writer
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
return &RESPWriter{w: w}
}
// Release returns the writer to the pool for reuse
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
func (w *RESPWriter) Release() {
w.w = nil
writerPool.Put(w)
// No-op: pooling removed for Yaegi compatibility
}
// WriteCommand writes a Redis command in RESP array format
@@ -78,17 +60,16 @@ type RESPReader struct {
r *bufio.Reader
}
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
// NewRESPReader creates a new RESP reader
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
return &RESPReader{
r: bufio.NewReaderSize(r, 4096),
}
}
// Release returns the reader to the pool for reuse
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
func (r *RESPReader) Release() {
r.r.Reset(nil)
readerPool.Put(r)
// No-op: pooling removed for Yaegi compatibility
}
// ReadResponse reads a RESP response and returns the parsed value
+5 -5
View File
@@ -15,8 +15,8 @@ import (
func TestRESPWriter_WriteCommand(t *testing.T) {
tests := []struct {
name string
args []string
expected string
args []string
}{
{
name: "Simple command",
@@ -205,9 +205,9 @@ func TestRESPReader_ReadInteger(t *testing.T) {
// TestRESPReader_ReadBulkString tests reading bulk strings
func TestRESPReader_ReadBulkString(t *testing.T) {
tests := []struct {
expected interface{}
name string
input string
expected interface{}
wantErr bool
isNil bool
}{
@@ -440,10 +440,10 @@ func TestRESPHelpers(t *testing.T) {
// TestRESPRoundTrip tests full round-trip encoding/decoding
func TestRESPRoundTrip(t *testing.T) {
tests := []struct {
name string
command []string
response string
expected interface{}
name string
response string
command []string
}{
{
name: "PING command",
+183
View File
@@ -0,0 +1,183 @@
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// SingleflightCache wraps a CacheBackend with singleflight deduplication
// to prevent thundering herd problems when multiple concurrent requests
// try to fetch the same uncached key.
type SingleflightCache struct {
backend CacheBackend
mu sync.Mutex
calls map[string]*singleflightCall
// Metrics
deduplicatedCalls atomic.Int64
totalCalls atomic.Int64
}
// singleflightCall represents an in-flight or completed fetch call
type singleflightCall struct {
wg sync.WaitGroup
val []byte
ttl time.Duration
err error
done bool
}
// NewSingleflightCache creates a new singleflight-wrapped cache backend
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
return &SingleflightCache{
backend: backend,
calls: make(map[string]*singleflightCall),
}
}
// Fetcher is a function type that fetches data when cache misses
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
// per key when there's a cache miss. Concurrent calls for the same key will
// wait for the first call to complete and share its result.
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
s.totalCalls.Add(1)
// Try cache first
value, _, exists, err := s.backend.Get(ctx, key)
if err != nil {
return nil, err
}
if exists {
return value, nil
}
// Cache miss - use singleflight
s.mu.Lock()
// Check if there's already an in-flight call for this key
if call, ok := s.calls[key]; ok {
s.mu.Unlock()
s.deduplicatedCalls.Add(1)
// Wait for the in-flight call to complete
call.wg.Wait()
// Check context cancellation
if ctx.Err() != nil {
return nil, ctx.Err()
}
return call.val, call.err
}
// Create new call
call := &singleflightCall{}
call.wg.Add(1)
s.calls[key] = call
s.mu.Unlock()
// Execute the fetcher
call.val, call.ttl, call.err = fetcher(ctx)
call.done = true
// If successful, store in cache
if call.err == nil && call.val != nil {
// Use a background context for cache storage to ensure it completes
// even if the original context is canceled
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
cancel()
}
// Signal waiting goroutines
call.wg.Done()
// Clean up the call from the map after a short delay
// This allows late arrivals to still benefit from the result
go func() {
time.Sleep(100 * time.Millisecond)
s.mu.Lock()
if c, ok := s.calls[key]; ok && c == call {
delete(s.calls, key)
}
s.mu.Unlock()
}()
return call.val, call.err
}
// Get retrieves a value from the underlying cache backend
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
return s.backend.Get(ctx, key)
}
// Set stores a value in the underlying cache backend
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return s.backend.Set(ctx, key, value, ttl)
}
// Delete removes a key from the underlying cache backend
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
return s.backend.Delete(ctx, key)
}
// Exists checks if a key exists in the underlying cache backend
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
return s.backend.Exists(ctx, key)
}
// Clear removes all keys from the underlying cache backend
func (s *SingleflightCache) Clear(ctx context.Context) error {
return s.backend.Clear(ctx)
}
// GetStats returns cache statistics including singleflight metrics
func (s *SingleflightCache) GetStats() map[string]interface{} {
stats := s.backend.GetStats()
// Add singleflight-specific stats
totalCalls := s.totalCalls.Load()
deduped := s.deduplicatedCalls.Load()
stats["singleflight_total_calls"] = totalCalls
stats["singleflight_deduplicated"] = deduped
if totalCalls > 0 {
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
} else {
stats["singleflight_dedup_rate"] = float64(0)
}
s.mu.Lock()
stats["singleflight_inflight"] = len(s.calls)
s.mu.Unlock()
return stats
}
// Close shuts down the cache backend
func (s *SingleflightCache) Close() error {
return s.backend.Close()
}
// Ping checks if the backend is healthy
func (s *SingleflightCache) Ping(ctx context.Context) error {
return s.backend.Ping(ctx)
}
// GetBackend returns the underlying cache backend
func (s *SingleflightCache) GetBackend() CacheBackend {
return s.backend
}
// ResetStats resets the singleflight statistics
func (s *SingleflightCache) ResetStats() {
s.totalCalls.Store(0)
s.deduplicatedCalls.Store(0)
}
// Ensure SingleflightCache implements CacheBackend
var _ CacheBackend = (*SingleflightCache)(nil)
+510
View File
@@ -0,0 +1,510 @@
package backends
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
t.Run("CacheHit", func(t *testing.T) {
key := "existing-key"
value := []byte("existing-value")
// Pre-populate cache
err := cache.Set(ctx, key, value, time.Minute)
require.NoError(t, err)
var fetchCalled bool
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCalled = true
return []byte("fetched-value"), time.Minute, nil
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
require.NoError(t, err)
assert.Equal(t, value, result)
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
})
t.Run("CacheMiss", func(t *testing.T) {
key := "missing-key"
expectedValue := []byte("fetched-value")
var fetchCalled bool
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCalled = true
return expectedValue, time.Minute, nil
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
require.NoError(t, err)
assert.Equal(t, expectedValue, result)
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
// Verify value was stored in cache
cached, _, exists, err := cache.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, cached)
})
t.Run("FetcherError", func(t *testing.T) {
key := "error-key"
expectedErr := errors.New("fetch failed")
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return nil, 0, expectedErr
}
result, err := cache.GetOrFetch(ctx, key, fetcher)
assert.Error(t, err)
assert.Equal(t, expectedErr, err)
assert.Nil(t, result)
// Verify nothing was stored in cache
_, _, exists, err := cache.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
}
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
func TestSingleflightCache_Deduplication(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
key := "dedup-key"
expectedValue := []byte("dedup-value")
var fetchCount atomic.Int32
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
// Simulate slow fetch
time.Sleep(100 * time.Millisecond)
return expectedValue, time.Minute, nil
}
// Launch multiple concurrent requests
concurrency := 10
var wg sync.WaitGroup
results := make([][]byte, concurrency)
errs := make([]error, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
}(i)
}
wg.Wait()
// Verify all requests got the same result
for i := 0; i < concurrency; i++ {
assert.NoError(t, errs[i])
assert.Equal(t, expectedValue, results[i])
}
// Verify fetcher was only called once
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
// Verify deduplication stats
stats := cache.GetStats()
deduped := stats["singleflight_deduplicated"].(int64)
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
}
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
func TestSingleflightCache_DifferentKeys(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
var fetchCount atomic.Int32
fetchStarted := make(chan struct{}, 3)
fetchComplete := make(chan struct{})
fetcher := func(key string) Fetcher {
return func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
fetchStarted <- struct{}{}
<-fetchComplete // Wait for signal
return []byte("value-" + key), time.Minute, nil
}
}
// Launch concurrent requests for different keys
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
key := fmt.Sprintf("key-%d", idx)
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
}(i)
}
// Wait for all fetches to start
for i := 0; i < 3; i++ {
<-fetchStarted
}
// All 3 fetches should be running in parallel
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
// Release all fetches
close(fetchComplete)
wg.Wait()
}
// TestSingleflightCache_ContextCancellation tests context cancellation
func TestSingleflightCache_ContextCancellation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
key := "cancel-key"
fetchStarted := make(chan struct{})
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
close(fetchStarted)
// Simulate slow fetch
time.Sleep(500 * time.Millisecond)
return []byte("value"), time.Minute, nil
}
// Start first request with long timeout
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()
_, _ = cache.GetOrFetch(ctx, key, fetcher)
}()
// Wait for fetch to start
<-fetchStarted
// Start second request with short timeout
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = cache.GetOrFetch(ctx, key, fetcher)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
wg.Wait()
}
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
key := "error-prop-key"
expectedErr := errors.New("intentional error")
var fetchCount atomic.Int32
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
fetchCount.Add(1)
time.Sleep(50 * time.Millisecond)
return nil, 0, expectedErr
}
// Launch multiple concurrent requests
concurrency := 5
var wg sync.WaitGroup
errs := make([]error, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
}(i)
}
wg.Wait()
// Verify all requests got the same error
for i := 0; i < concurrency; i++ {
assert.Error(t, errs[i])
assert.Equal(t, expectedErr, errs[i])
}
// Verify fetcher was only called once
assert.Equal(t, int32(1), fetchCount.Load())
}
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
t.Run("Set", func(t *testing.T) {
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
require.NoError(t, err)
val, _, exists, err := cache.Get(ctx, "set-key")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("set-value"), val)
})
t.Run("Get", func(t *testing.T) {
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
require.NoError(t, err)
val, ttl, exists, err := cache.Get(ctx, "get-key")
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, []byte("get-value"), val)
assert.Greater(t, ttl, time.Duration(0))
})
t.Run("Delete", func(t *testing.T) {
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
require.NoError(t, err)
deleted, err := cache.Delete(ctx, "delete-key")
require.NoError(t, err)
assert.True(t, deleted)
exists, err := cache.Exists(ctx, "delete-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
exists, err := cache.Exists(ctx, "nonexistent")
require.NoError(t, err)
assert.False(t, exists)
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
require.NoError(t, err)
exists, err = cache.Exists(ctx, "exists-key")
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
require.NoError(t, err)
err = cache.Clear(ctx)
require.NoError(t, err)
exists, err := cache.Exists(ctx, "clear-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Ping", func(t *testing.T) {
err := cache.Ping(ctx)
require.NoError(t, err)
})
}
// TestSingleflightCache_Stats tests statistics tracking
func TestSingleflightCache_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
// Make some calls
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(50 * time.Millisecond)
return []byte("value"), time.Minute, nil
}
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
}()
}
wg.Wait()
stats := cache.GetStats()
// Check singleflight stats exist
assert.Contains(t, stats, "singleflight_total_calls")
assert.Contains(t, stats, "singleflight_deduplicated")
assert.Contains(t, stats, "singleflight_dedup_rate")
assert.Contains(t, stats, "singleflight_inflight")
// Verify values
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
// Also check underlying backend stats are included
assert.Contains(t, stats, "hits")
assert.Contains(t, stats, "misses")
}
// TestSingleflightCache_ResetStats tests stats reset
func TestSingleflightCache_ResetStats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return []byte("value"), time.Minute, nil
}
// Make some calls
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
stats := cache.GetStats()
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
// Reset stats
cache.ResetStats()
stats = cache.GetStats()
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
}
// TestSingleflightCache_GetBackend tests GetBackend method
func TestSingleflightCache_GetBackend(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
cache := NewSingleflightCache(backend)
assert.Equal(t, backend, cache.GetBackend())
}
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := fmt.Sprintf("key-%d", i%100)
_, _ = cache.GetOrFetch(ctx, key, fetcher)
}
}
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(time.Millisecond) // Simulate slow fetch
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
_, _ = cache.GetOrFetch(ctx, key, fetcher)
i++
}
})
}
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
backend, _ := NewMemoryBackend(DefaultConfig())
defer backend.Close()
cache := NewSingleflightCache(backend)
ctx := context.Background()
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
return []byte("benchmark-value"), time.Minute, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
// All goroutines hit the same key
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
}
})
}
+28 -40
View File
@@ -33,21 +33,19 @@ type Logger interface {
// Config provides configuration for the cache
type Config struct {
Logger Logger
JWKConfig *JWKConfig
MetadataConfig *MetadataConfig
TokenConfig *TokenConfig
Type Type
MaxSize int
MaxMemoryBytes int64
DefaultTTL time.Duration
CleanupInterval time.Duration
EnableCompression bool
MaxMemoryBytes int64
MaxSize int
EnableMetrics bool
EnableAutoCleanup bool
EnableMemoryLimit bool
Logger Logger
// Type-specific configurations
TokenConfig *TokenConfig
MetadataConfig *MetadataConfig
JWKConfig *JWKConfig
EnableCompression bool
}
// TokenConfig provides token-specific cache configuration
@@ -59,11 +57,11 @@ type TokenConfig struct {
// MetadataConfig provides metadata-specific cache configuration
type MetadataConfig struct {
SecurityCriticalFields []string
GracePeriod time.Duration
ExtendedGracePeriod time.Duration
MaxGracePeriod time.Duration
SecurityCriticalMaxGracePeriod time.Duration
SecurityCriticalFields []string
}
// JWKConfig provides JWK-specific cache configuration
@@ -75,45 +73,35 @@ type JWKConfig struct {
// Item represents a single cache entry
type Item struct {
Key string
Value interface{}
Size int64
ExpiresAt time.Time
LastAccessed time.Time
AccessCount int64
Value interface{}
Metadata map[string]interface{}
element *list.Element
Key string
CacheType Type
// Type-specific metadata
Metadata map[string]interface{}
// LRU list element reference
element *list.Element
Size int64
AccessCount int64
}
// Cache provides a single, unified cache implementation
type Cache struct {
mu sync.RWMutex
items map[string]*Item
lruList *list.List
config Config
logger Logger
// Memory management
config Config
ctx context.Context
logger Logger
cancel context.CancelFunc
lruList *list.List
items map[string]*Item
stopCleanup chan bool
wg sync.WaitGroup
currentSize int64
currentMemory int64
// Metrics
hits int64
misses int64
evictions int64
sets int64
// Lifecycle management
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
stopCleanup chan bool
closed int32
hits int64
misses int64
evictions int64
sets int64
mu sync.RWMutex
closed int32
}
// DefaultConfig returns a default cache configuration
+11 -11
View File
@@ -1750,19 +1750,19 @@ func TestAdvancedEdgeCases(t *testing.T) {
// Test with various data types
testCases := []struct {
key string
value interface{}
key string
}{
{"string", "test string"},
{"int", 42},
{"float", 3.14159},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"one": 1, "two": 2}},
{"nil", nil},
{"empty-string", ""},
{"empty-slice", []string{}},
{"empty-map", map[string]interface{}{}},
{key: "string", value: "test string"},
{key: "int", value: 42},
{key: "float", value: 3.14159},
{key: "bool", value: true},
{key: "slice", value: []string{"a", "b", "c"}},
{key: "map", value: map[string]int{"one": 1, "two": 2}},
{key: "nil", value: nil},
{key: "empty-string", value: ""},
{key: "empty-slice", value: []string{}},
{key: "empty-map", value: map[string]interface{}{}},
}
for _, tc := range testCases {
+3 -8
View File
@@ -7,22 +7,17 @@ import (
// Manager manages multiple cache instances with singleton pattern
type Manager struct {
mu sync.RWMutex
// Core caches
logger Logger
tokenCache *Cache
metadataCache *Cache
jwkCache *Cache
sessionCache *Cache
generalCache *Cache
// Typed wrappers
typedToken *TokenCache
typedMetadata *MetadataCache
typedJWK *JWKCache
typedSession *SessionCache
logger Logger
mu sync.RWMutex
}
var (
@@ -237,7 +232,7 @@ func (m *Manager) Close() error {
var firstErr error
if err := m.tokenCache.Close(); err != nil && firstErr == nil {
if err := m.tokenCache.Close(); err != nil {
firstErr = err
}
if err := m.metadataCache.Close(); err != nil && firstErr == nil {
+23 -42
View File
@@ -48,23 +48,12 @@ func (s State) String() string {
// CircuitBreakerConfig holds configuration for the circuit breaker
type CircuitBreakerConfig struct {
// MaxFailures is the number of consecutive failures before opening the circuit
MaxFailures int
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
FailureThreshold float64
// Timeout is how long the circuit stays open before trying half-open
Timeout time.Duration
// HalfOpenMaxRequests is the number of requests allowed in half-open state
OnStateChange func(from, to State)
MaxFailures int
FailureThreshold float64
Timeout time.Duration
HalfOpenMaxRequests int
// ResetTimeout is how long to wait before resetting counters in closed state
ResetTimeout time.Duration
// OnStateChange is called when the circuit breaker changes state
OnStateChange func(from, to State)
ResetTimeout time.Duration
}
// DefaultCircuitBreakerConfig returns default configuration
@@ -80,28 +69,20 @@ func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
// CircuitBreaker implements the circuit breaker pattern
type CircuitBreaker struct {
config *CircuitBreakerConfig
// State management
state atomic.Int32
lastStateChange time.Time
stateMu sync.RWMutex
// Failure tracking
consecutiveFailures atomic.Int32
totalRequests atomic.Int64
nextRetryTime time.Time
lastStateChange time.Time
lastSuccessTime time.Time
lastFailureTime time.Time
config *CircuitBreakerConfig
totalFailures atomic.Int64
totalRequests atomic.Int64
stateTransitions atomic.Int64
rejectedRequests atomic.Int64
stateMu sync.RWMutex
timeMu sync.RWMutex
halfOpenRequests atomic.Int32
// Timing
lastFailureTime time.Time
lastSuccessTime time.Time
nextRetryTime time.Time
timeMu sync.RWMutex
// Metrics
stateTransitions atomic.Int64
rejectedRequests atomic.Int64
consecutiveFailures atomic.Int32
state atomic.Int32
}
// NewCircuitBreaker creates a new circuit breaker
@@ -313,17 +294,17 @@ func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
// CircuitBreakerStats holds statistics for the circuit breaker
type CircuitBreakerStats struct {
State State
ConsecutiveFailures int32
LastFailureTime time.Time
LastSuccessTime time.Time
LastStateChange time.Time
NextRetryTime time.Time
TotalRequests int64
TotalFailures int64
SuccessRate float64
RejectedRequests int64
StateTransitions int64
LastFailureTime time.Time
LastSuccessTime time.Time
LastStateChange time.Time
NextRetryTime time.Time
State State
ConsecutiveFailures int32
}
// IsHealthy returns true if the circuit breaker is in a healthy state
+1 -1
View File
@@ -28,8 +28,8 @@ type mockBackend struct {
}
type mockEntry struct {
value []byte
expiresAt time.Time
value []byte
}
func newMockBackend() *mockBackend {
+28 -49
View File
@@ -41,26 +41,13 @@ func (h HealthStatus) String() string {
// HealthCheckConfig holds configuration for the health checker
type HealthCheckConfig struct {
// CheckInterval is how often to check health
CheckInterval time.Duration
// Timeout is the timeout for each health check
Timeout time.Duration
// HealthyThreshold is the number of consecutive successes to become healthy
HealthyThreshold int
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
OnStatusChange func(from, to HealthStatus)
CheckFunc func(ctx context.Context) error
CheckInterval time.Duration
Timeout time.Duration
HealthyThreshold int
UnhealthyThreshold int
// DegradedThreshold is the latency threshold in ms to mark as degraded
DegradedThreshold time.Duration
// OnStatusChange is called when health status changes
OnStatusChange func(from, to HealthStatus)
// CheckFunc is the function to check health
CheckFunc func(ctx context.Context) error
DegradedThreshold time.Duration
}
// DefaultHealthCheckConfig returns default configuration
@@ -76,31 +63,23 @@ func DefaultHealthCheckConfig() *HealthCheckConfig {
// HealthChecker monitors the health of a backend
type HealthChecker struct {
config *HealthCheckConfig
// Status tracking
status atomic.Int32
consecutiveSuccesses atomic.Int32
lastCheckTime time.Time
lastSuccessTime time.Time
lastFailureTime time.Time
config *HealthCheckConfig
stopChan chan struct{}
ticker *time.Ticker
wg sync.WaitGroup
statusChanges atomic.Int64
totalChecks atomic.Int64
totalSuccesses atomic.Int64
totalFailures atomic.Int64
averageLatency atomic.Int64
timeMu sync.RWMutex
consecutiveFailures atomic.Int32
// Timing
lastCheckTime time.Time
lastSuccessTime time.Time
lastFailureTime time.Time
averageLatency atomic.Int64
timeMu sync.RWMutex
// Metrics
totalChecks atomic.Int64
totalSuccesses atomic.Int64
totalFailures atomic.Int64
statusChanges atomic.Int64
// Lifecycle
ticker *time.Ticker
stopChan chan struct{}
stopped atomic.Bool
wg sync.WaitGroup
consecutiveSuccesses atomic.Int32
stopped atomic.Bool
status atomic.Int32
}
// NewHealthChecker creates a new health checker
@@ -342,19 +321,19 @@ func (hc *HealthChecker) Stats() HealthCheckerStats {
// HealthCheckerStats holds statistics for the health checker
type HealthCheckerStats struct {
Status HealthStatus
ConsecutiveSuccesses int32
ConsecutiveFailures int32
LastCheckTime time.Time
LastFailureTime time.Time
LastSuccessTime time.Time
TotalChecks int64
TotalSuccesses int64
TotalFailures int64
SuccessRate float64
AverageLatency time.Duration
StatusChanges int64
LastCheckTime time.Time
LastSuccessTime time.Time
LastFailureTime time.Time
HealthScore float64
Status HealthStatus
ConsecutiveFailures int32
ConsecutiveSuccesses int32
}
// Reset resets the health checker statistics
+7 -11
View File
@@ -12,20 +12,16 @@ import (
// HealthCheckBackend wraps a cache backend with health checking
type HealthCheckBackend struct {
backend backends.CacheBackend
config *HealthCheckConfig
// Health tracking
lastCheck time.Time
backend backends.CacheBackend
ctx context.Context
config *HealthCheckConfig
cancel context.CancelFunc
wg sync.WaitGroup
checkMutex sync.RWMutex
status atomic.Int32
consecutiveFails atomic.Int32
consecutiveOK atomic.Int32
lastCheck time.Time
checkMutex sync.RWMutex
// Lifecycle
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewHealthCheckBackend creates a new health check wrapped backend
+2 -2
View File
@@ -292,12 +292,12 @@ type SessionCache struct {
// SessionData represents session information
type SessionData struct {
ExpiresAt time.Time `json:"expires_at"`
Claims map[string]interface{} `json:"claims"`
ID string `json:"id"`
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresAt time.Time `json:"expires_at"`
Claims map[string]interface{} `json:"claims"`
}
// NewSessionCache creates a new session cache
+1 -1
View File
@@ -11,10 +11,10 @@ import (
// Mock logger for testing
type mockLogger struct {
mu sync.Mutex
logs []string
errLogs []string
debugLog []string
mu sync.Mutex
}
func (m *mockLogger) Logf(format string, args ...interface{}) {
+11 -11
View File
@@ -19,20 +19,20 @@ type Logger interface {
// BackgroundTask represents a recurring background task
type BackgroundTask struct {
name string
interval time.Duration
taskFunc func()
lastRun time.Time
logger Logger
ctx context.Context
ticker *time.Ticker
stopChan chan bool
isRunning int32
logger Logger
waitGroup *sync.WaitGroup
lastRun time.Time
taskFunc func()
cancelFunc context.CancelFunc
name string
runCount int64
errorCount int64
interval time.Duration
mu sync.RWMutex
ctx context.Context
cancelFunc context.CancelFunc
isRunning int32
}
// NewBackgroundTask creates a new background task
@@ -183,11 +183,11 @@ func (bt *BackgroundTask) IsRunning() bool {
// TaskRegistry manages all background tasks
type TaskRegistry struct {
tasks map[string]*BackgroundTask
mu sync.RWMutex
logger Logger
maxTasks int
tasks map[string]*BackgroundTask
circuitBreaker *TaskCircuitBreaker
maxTasks int
mu sync.RWMutex
}
// globalTaskRegistry is the singleton task registry
+14 -14
View File
@@ -11,14 +11,14 @@ import (
// TaskCircuitBreaker prevents task creation failures from cascading
type TaskCircuitBreaker struct {
lastFailureTime time.Time
logger Logger
taskFailures map[string]int32
timeout time.Duration
mu sync.RWMutex
failureThreshold int32
failureCount int32
lastFailureTime time.Time
timeout time.Duration
state int32 // 0: closed, 1: open
logger Logger
mu sync.RWMutex
taskFailures map[string]int32
state int32
}
// CircuitBreakerState represents the state of the circuit breaker
@@ -140,14 +140,14 @@ func (cb *TaskCircuitBreaker) GetState() CircuitBreakerState {
// TaskMemoryMonitor monitors memory usage and can trigger cleanup
type TaskMemoryMonitor struct {
lastCheck time.Time
logger Logger
registry *TaskRegistry
stopChan chan bool
memoryThreshold uint64
checkInterval time.Duration
isMonitoring int32
stopChan chan bool
lastCheck time.Time
mu sync.RWMutex
isMonitoring int32
}
var (
@@ -310,13 +310,13 @@ func (tmm *TaskMemoryMonitor) GetStats() map[string]interface{} {
// WorkerPool manages a pool of worker goroutines for task execution
type WorkerPool struct {
workers int
taskQueue chan func()
workerWg sync.WaitGroup
isRunning int32
logger Logger
taskQueue chan func()
stopChan chan bool
metrics WorkerPoolMetrics
workerWg sync.WaitGroup
workers int
isRunning int32
}
// WorkerPoolMetrics tracks worker pool performance
@@ -397,7 +397,7 @@ func (wp *WorkerPool) Submit(task func()) error {
}
// worker is the main worker routine
func (wp *WorkerPool) worker(id int) {
func (wp *WorkerPool) worker(_ int) {
defer wp.workerWg.Done()
for {
+155
View File
@@ -0,0 +1,155 @@
package dcrstorage
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// FileStore implements Store using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
// For distributed environments, consider using RedisStore instead.
type FileStore struct {
basePath string
logger Logger
mu sync.RWMutex
}
// NewFileStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileStore(basePath string, logger Logger) *FileStore {
if basePath == "" {
basePath = "/tmp/oidc-client-credentials.json"
}
if logger == nil {
logger = NoOpLogger()
}
return &FileStore{
basePath: basePath,
logger: logger,
}
}
// BasePath returns the base path used for storing credentials
func (s *FileStore) BasePath() string {
return s.basePath
}
// GetFilePath returns the file path for storing credentials for a specific provider.
// For multi-tenant scenarios, each provider gets a separate file based on URL hash.
func (s *FileStore) GetFilePath(providerURL string) string {
if providerURL == "" {
return s.basePath
}
// Hash provider URL for filename safety and uniqueness
hash := sha256.Sum256([]byte(providerURL))
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes for shorter filename
ext := filepath.Ext(s.basePath)
base := strings.TrimSuffix(s.basePath, ext)
if ext == "" {
ext = ".json"
}
return fmt.Sprintf("%s-%s%s", base, hashStr, ext)
}
// Save stores the client registration response to a file
func (s *FileStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
if creds == nil {
return fmt.Errorf("credentials cannot be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
filePath := s.GetFilePath(providerURL)
// Ensure parent directory exists
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create credentials directory: %w", err)
}
data, err := json.MarshalIndent(creds, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Write with restrictive permissions (owner read/write only)
if err := os.WriteFile(filePath, data, 0600); err != nil {
return fmt.Errorf("failed to write credentials file: %w", err)
}
s.logger.Debugf("Saved client credentials to %s", filePath)
return nil
}
// Load retrieves stored credentials from a file.
// Returns nil, nil if no credentials file exists (not an error).
func (s *FileStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filePath := s.GetFilePath(providerURL)
// #nosec G304 -- path is constructed from trusted config values via GetFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No credentials file exists - not an error
}
return nil, fmt.Errorf("failed to read credentials file: %w", err)
}
var creds ClientRegistrationResponse
if err := json.Unmarshal(data, &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
}
s.logger.Debugf("Loaded client credentials from %s", filePath)
return &creds, nil
}
// Delete removes the credentials file for a provider
func (s *FileStore) Delete(ctx context.Context, providerURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
filePath := s.GetFilePath(providerURL)
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
return nil // File doesn't exist, nothing to delete
}
return fmt.Errorf("failed to remove credentials file: %w", err)
}
s.logger.Debugf("Deleted client credentials from %s", filePath)
return nil
}
// Exists checks if credentials exist for a provider
func (s *FileStore) Exists(ctx context.Context, providerURL string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filePath := s.GetFilePath(providerURL)
_, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check credentials file: %w", err)
}
return true, nil
}
+161
View File
@@ -0,0 +1,161 @@
package dcrstorage
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
)
// Cache defines the interface for cache operations needed by RedisStore.
// This allows the main package to provide a cache implementation without
// creating circular dependencies.
type Cache interface {
// Get retrieves a value from the cache
Get(key string) (any, bool)
// Set stores a value in the cache with a TTL
Set(key string, value any, ttl time.Duration) error
// Delete removes a value from the cache
Delete(key string)
}
// RedisStore implements Store using a Cache-backed storage.
// This storage backend enables sharing DCR credentials across multiple Traefik instances
// in distributed environments (e.g., Kubernetes with multiple ingress pods).
type RedisStore struct {
cache Cache
keyPrefix string
logger Logger
mu sync.RWMutex
}
// NewRedisStore creates a new cache-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisStore(cache Cache, keyPrefix string, logger Logger) *RedisStore {
if keyPrefix == "" {
keyPrefix = "dcr:creds:"
}
if logger == nil {
logger = NoOpLogger()
}
return &RedisStore{
cache: cache,
keyPrefix: keyPrefix,
logger: logger,
}
}
// makeKey creates a unique cache key for a provider URL.
// Uses SHA256 hash of the provider URL for consistent key generation across nodes.
func (s *RedisStore) makeKey(providerURL string) string {
if providerURL == "" {
return s.keyPrefix + "default"
}
hash := sha256.Sum256([]byte(providerURL))
return s.keyPrefix + hex.EncodeToString(hash[:])
}
// Save stores the client registration response in the cache.
// TTL is calculated based on client_secret_expires_at if available.
func (s *RedisStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
if creds == nil {
return fmt.Errorf("credentials cannot be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
key := s.makeKey(providerURL)
// Calculate TTL based on client_secret_expires_at if available
ttl := 30 * 24 * time.Hour // Default: 30 days
if creds.ClientSecretExpiresAt > 0 {
expiresAt := time.Unix(creds.ClientSecretExpiresAt, 0)
ttl = time.Until(expiresAt)
if ttl < 0 {
return fmt.Errorf("credentials already expired")
}
// Add a small buffer to ensure we don't serve expired credentials
if ttl > time.Minute {
ttl -= time.Minute
}
}
// Serialize credentials to JSON for storage
data, err := json.Marshal(creds)
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Store as string in cache (will be serialized by the cache backend)
if err := s.cache.Set(key, string(data), ttl); err != nil {
return fmt.Errorf("failed to store credentials in cache: %w", err)
}
s.logger.Debugf("Saved client credentials to cache with key %s (TTL: %v)", key, ttl)
return nil
}
// Load retrieves stored credentials from the cache.
// Returns nil, nil if no credentials exist (not an error).
func (s *RedisStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
key := s.makeKey(providerURL)
value, exists := s.cache.Get(key)
if !exists {
return nil, nil // No credentials stored - not an error
}
// Handle different value types from cache
var jsonData string
switch v := value.(type) {
case string:
jsonData = v
case []byte:
jsonData = string(v)
default:
// Try to see if it's already the struct (from local cache)
if creds, ok := value.(*ClientRegistrationResponse); ok {
return creds, nil
}
return nil, fmt.Errorf("unexpected credentials type in cache: %T", value)
}
var creds ClientRegistrationResponse
if err := json.Unmarshal([]byte(jsonData), &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials from cache: %w", err)
}
s.logger.Debugf("Loaded client credentials from cache with key %s", key)
return &creds, nil
}
// Delete removes stored credentials from the cache
func (s *RedisStore) Delete(ctx context.Context, providerURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
key := s.makeKey(providerURL)
s.cache.Delete(key)
s.logger.Debugf("Deleted client credentials from cache with key %s", key)
return nil
}
// Exists checks if credentials exist in the cache for a provider
func (s *RedisStore) Exists(ctx context.Context, providerURL string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
key := s.makeKey(providerURL)
_, exists := s.cache.Get(key)
return exists, nil
}
+90
View File
@@ -0,0 +1,90 @@
// Package dcrstorage provides storage backends for OIDC Dynamic Client Registration credentials.
// It supports both file-based and Redis-based storage for persisting client credentials
// across application restarts and distributed deployments.
package dcrstorage
import (
"context"
)
// StorageBackend represents the type of storage backend for DCR credentials
type StorageBackend string
const (
// StorageBackendFile uses file-based storage (default for backward compatibility)
StorageBackendFile StorageBackend = "file"
// StorageBackendRedis uses Redis for distributed storage
StorageBackendRedis StorageBackend = "redis"
// StorageBackendAuto automatically selects Redis if available, otherwise file
StorageBackendAuto StorageBackend = "auto"
)
// Logger interface for DCR storage operations
type Logger interface {
Debug(msg string)
Debugf(format string, args ...any)
Info(msg string)
Infof(format string, args ...any)
Error(msg string)
Errorf(format string, args ...any)
}
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
SubjectType string `json:"subject_type,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// Store defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type Store interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// noOpLogger is a no-op implementation of Logger for default use
type noOpLogger struct{}
func (n noOpLogger) Debug(msg string) {}
func (n noOpLogger) Debugf(format string, args ...any) {}
func (n noOpLogger) Info(msg string) {}
func (n noOpLogger) Infof(format string, args ...any) {}
func (n noOpLogger) Error(msg string) {}
func (n noOpLogger) Errorf(format string, args ...any) {}
// NoOpLogger returns a no-op logger instance
func NoOpLogger() Logger {
return noOpLogger{}
}
+464
View File
@@ -0,0 +1,464 @@
package dcrstorage
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// mockCache implements Cache for testing
type mockCache struct {
data map[string]cacheEntry
mu sync.RWMutex
}
type cacheEntry struct {
value any
expiresAt time.Time
}
func newMockCache() *mockCache {
return &mockCache{data: make(map[string]cacheEntry)}
}
func (m *mockCache) Get(key string) (any, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
entry, ok := m.data[key]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
return nil, false
}
return entry.value, true
}
func (m *mockCache) Set(key string, value any, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = cacheEntry{
value: value,
expiresAt: time.Now().Add(ttl),
}
return nil
}
func (m *mockCache) Delete(key string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
}
func TestFileStore_SaveLoad(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileStore(filepath.Join(tempDir2, "nonexistent.json"), nil)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
func TestFileStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
func TestFileStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
for range concurrency {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
for range concurrency {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
func TestFileStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
func TestFileStore_DefaultPath(t *testing.T) {
t.Parallel()
store := NewFileStore("", nil)
if store.BasePath() == "" {
t.Error("Expected default base path")
}
}
func TestRedisStore_WithMockCache(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
func TestRedisStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0,
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
func TestRedisStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
func TestFileStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
filePath := store.GetFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
func TestFileStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
store := NewFileStore(deepPath, nil)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+3 -3
View File
@@ -12,9 +12,9 @@ import (
type FeatureFlag struct {
name string
description string
enabled atomic.Bool
mu sync.RWMutex
callbacks []func(bool)
mu sync.RWMutex
enabled atomic.Bool
}
// FeatureManager manages all feature flags in the application
@@ -173,7 +173,7 @@ func (m *FeatureManager) LoadFromEnv() {
for name, flag := range flags {
envVar := "FEATURE_" + name
if value := os.Getenv(envVar); value != "" {
enabled := strings.ToLower(value) == "true" || value == "1"
enabled := strings.EqualFold(value, "true") || value == "1"
flag.enabled.Store(enabled)
}
}
+19 -28
View File
@@ -14,50 +14,41 @@ import (
// and resource leaks. It provides centralized management of HTTP client transports with
// proper lifecycle management and security controls.
type TransportPool struct {
mu sync.RWMutex
transports map[string]*sharedTransport
maxConns int
ctx context.Context
transports map[string]*sharedTransport
cancel context.CancelFunc
clientCount int32 // Track total HTTP clients
maxClients int32 // Limit total clients
maxConns int
mu sync.RWMutex
clientCount int32
maxClients int32
}
// sharedTransport wraps an HTTP transport with reference counting
type sharedTransport struct {
transport *http.Transport
refCount int32
lastUsed time.Time
transport *http.Transport
config TransportConfig
refCount int32
}
// TransportConfig defines configuration for HTTP transports
type TransportConfig struct {
// Timeouts
DialTimeout time.Duration
TLSHandshakeTimeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
ResponseHeaderTimeout time.Duration
ExpectContinueTimeout time.Duration
IdleConnTimeout time.Duration
KeepAlive time.Duration
// Connection limits
MaxIdleConns int
MaxIdleConnsPerHost int
MaxConnsPerHost int
// Features
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
// Buffer sizes
WriteBufferSize int
ReadBufferSize int
// TLS
InsecureSkipVerify bool
MinTLSVersion uint16
TLSHandshakeTimeout time.Duration
MaxIdleConns int
DialTimeout time.Duration
MaxIdleConnsPerHost int
ReadBufferSize int
MinTLSVersion uint16
ForceHTTP2 bool
DisableCompression bool
InsecureSkipVerify bool
DisableKeepAlives bool
}
var (
+1 -1
View File
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
var filteredScopes []string
for _, scope := range scopes {
if strings.ToLower(scope) != ScopeOfflineAccess {
if !strings.EqualFold(scope, ScopeOfflineAccess) {
filteredScopes = append(filteredScopes, scope)
}
}
+4 -4
View File
@@ -154,10 +154,10 @@ func TestAzureProvider_ValidateTokens(t *testing.T) {
provider := NewAzureProvider()
tests := []struct {
name string
session *mockSession
verifierError error
session *mockSession
cacheData map[string]interface{}
name string
expectedResult ValidationResult
}{
{
@@ -369,9 +369,9 @@ func TestAzureProvider_OfflineAccessHandling(t *testing.T) {
tests := []struct {
name string
inputScopes []string
expectedCount int // Expected number of offline_access scopes (should be 1)
description string
inputScopes []string
expectedCount int
}{
{
name: "No offline_access - should add one",
+5 -5
View File
@@ -8,10 +8,10 @@ import (
// Mock implementations for testing
type mockSession struct {
authenticated bool
idToken string
accessToken string
refreshToken string
authenticated bool
}
func (s *mockSession) GetIDToken() string { return s.idToken }
@@ -338,10 +338,10 @@ func TestBaseProvider_ValidateTokenExpiry(t *testing.T) {
gracePeriod := 5 * time.Minute
tests := []struct {
name string
claims map[string]interface{}
cacheFound bool
name string
expectedResult ValidationResult
cacheFound bool
}{
{
name: "Token not found in cache, has refresh token",
@@ -438,10 +438,10 @@ func TestBaseProvider_ValidateTokenExpiry_NoRefreshToken(t *testing.T) {
gracePeriod := 5 * time.Minute
tests := []struct {
name string
claims map[string]interface{}
cacheFound bool
name string
expectedResult ValidationResult
cacheFound bool
}{
{
name: "Token not found in cache, no refresh token",
+2 -2
View File
@@ -25,9 +25,9 @@ func TestProviderFactory_CreateProvider(t *testing.T) {
tests := []struct {
name string
issuerURL string
errMsg string
expectedType ProviderType
wantErr bool
errMsg string
}{
{
name: "Google provider",
@@ -158,10 +158,10 @@ func TestProviderFactory_CreateProviderByType(t *testing.T) {
tests := []struct {
name string
errMsg string
providerType ProviderType
expectedType ProviderType
wantErr bool
errMsg string
}{
{
name: "Generic provider",
+2 -2
View File
@@ -136,9 +136,9 @@ func TestGenericProvider_ValidateTokens(t *testing.T) {
provider := NewGenericProvider()
tests := []struct {
name string
session *mockSession
verifierError error
session *mockSession
name string
expectedResult ValidationResult
}{
{
+1 -1
View File
@@ -172,8 +172,8 @@ func TestGoogleProvider_OfflineAccessFiltering(t *testing.T) {
tests := []struct {
name string
inputScopes []string
description string
inputScopes []string
}{
{
name: "Multiple offline_access occurrences",
+3 -3
View File
@@ -82,9 +82,9 @@ func TestProviderRegistry_GetProviderByType(t *testing.T) {
registry.RegisterProvider(googleProvider)
tests := []struct {
expected OIDCProvider
name string
providerType ProviderType
expected OIDCProvider
}{
{
name: "Get Generic provider",
@@ -180,9 +180,9 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
registry.RegisterProvider(gitlabProvider)
tests := []struct {
expected OIDCProvider
name string
issuerURL string
expected OIDCProvider
}{
{
name: "Google provider detection",
@@ -640,9 +640,9 @@ func TestProviderRegistry_GitLabDetection_RealWorldURLs(t *testing.T) {
registry.RegisterProvider(githubProvider)
realWorldTests := []struct {
expected OIDCProvider
name string
issuerURL string
expected OIDCProvider
}{
// Actual self-hosted GitLab examples from issue #61
{
+9 -9
View File
@@ -20,8 +20,8 @@ func TestValidateIssuerURL(t *testing.T) {
tests := []struct {
name string
issuerURL string
wantErr bool
errMsg string
wantErr bool
}{
{
name: "valid https URL",
@@ -106,8 +106,8 @@ func TestValidateClientID(t *testing.T) {
tests := []struct {
name string
clientID string
wantErr bool
errMsg string
wantErr bool
}{
{
name: "valid client ID",
@@ -173,9 +173,9 @@ func TestValidateClientID(t *testing.T) {
func TestValidateScopes(t *testing.T) {
tests := []struct {
name string
errMsg string
scopes []string
wantErr bool
errMsg string
}{
{
name: "valid scopes with openid",
@@ -248,8 +248,8 @@ func TestValidateRedirectURL(t *testing.T) {
tests := []struct {
name string
redirectURL string
wantErr bool
errMsg string
wantErr bool
}{
{
name: "valid https redirect URL",
@@ -315,11 +315,11 @@ func TestValidateRedirectURL(t *testing.T) {
// TestValidateProviderSpecificConfig tests provider-specific configuration validation
func TestValidateProviderSpecificConfig(t *testing.T) {
tests := []struct {
name string
provider OIDCProvider
config map[string]interface{}
wantErr bool
name string
errMsg string
wantErr bool
}{
{
name: "valid Google config",
@@ -458,8 +458,8 @@ func TestValidateGoogleConfig_EdgeCases(t *testing.T) {
googleProvider := NewGoogleProvider()
tests := []struct {
name string
config map[string]interface{}
name string
wantErr bool
}{
{
@@ -502,10 +502,10 @@ func TestValidateAzureConfig_EdgeCases(t *testing.T) {
azureProvider := NewAzureProvider()
tests := []struct {
name string
config map[string]interface{}
wantErr bool
name string
errMsg string
wantErr bool
}{
{
name: "valid tenant ID format",
+13 -12
View File
@@ -7,9 +7,9 @@ import (
// ProviderWarning represents a warning about provider limitations or requirements.
type ProviderWarning struct {
ProviderType ProviderType
Level string // "info", "warning", "error"
Level string
Message string
ProviderType ProviderType
}
// GetProviderWarnings returns warnings about provider-specific limitations.
@@ -18,16 +18,17 @@ func GetProviderWarnings(providerType ProviderType) []ProviderWarning {
switch providerType {
case ProviderTypeGitHub:
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
})
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
warnings = append(warnings,
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
},
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
case ProviderTypeAuth0:
warnings = append(warnings, ProviderWarning{
+1 -1
View File
@@ -9,9 +9,9 @@ import (
func TestGetProviderWarnings(t *testing.T) {
tests := []struct {
name string
checkContent string
providerType ProviderType
expectCount int
checkContent string
}{
{
name: "GitHub has OAuth 2.0 warning",
+8 -14
View File
@@ -34,21 +34,15 @@ type Logger interface {
// for all recovery mechanism implementations. It handles request counting,
// success/failure tracking, and timestamp management in a thread-safe manner.
type BaseRecoveryMechanism struct {
// name identifies the recovery mechanism instance
name string
// logger provides structured logging capabilities
logger Logger
// Metrics tracked with atomic operations for thread safety
logger Logger
name string
lastSuccessStr string
lastFailureStr string
totalRequests int64
successCount int64
failureCount int64
lastSuccessStr string
lastFailureStr string
// mutexes for thread-safe timestamp updates
successMutex sync.RWMutex
failureMutex sync.RWMutex
successMutex sync.RWMutex
failureMutex sync.RWMutex
}
// NewBaseRecoveryMechanism creates a new base recovery mechanism with the given name and logger.
@@ -182,10 +176,10 @@ const (
// HTTPError represents an HTTP error with status code and message
type HTTPError struct {
StatusCode int
Headers map[string]string
Message string
Body []byte
Headers map[string]string
StatusCode int
}
// Error implements the error interface
+7 -13
View File
@@ -60,20 +60,14 @@ func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
// It prevents cascading failures by temporarily blocking requests to a failing service.
type CircuitBreaker struct {
*BaseRecoveryMechanism
config CircuitBreakerConfig
// State management
state int32 // atomic: CircuitBreakerState
lastStateChange time.Time
stateMutex sync.RWMutex
// Failure tracking
consecutiveFailures int32 // atomic
consecutiveSuccesses int32 // atomic
// Half-open state management
halfOpenRequests int32 // atomic
*BaseRecoveryMechanism
config CircuitBreakerConfig
stateMutex sync.RWMutex
state int32
consecutiveFailures int32
consecutiveSuccesses int32
halfOpenRequests int32
}
// NewCircuitBreaker creates a new circuit breaker with the given configuration
+12 -20
View File
@@ -15,20 +15,13 @@ import (
// RetryConfig defines configuration for the retry executor
type RetryConfig struct {
// MaxAttempts is the maximum number of retry attempts
MaxAttempts int
// InitialDelay is the initial delay between retries
InitialDelay time.Duration
// MaxDelay is the maximum delay between retries
MaxDelay time.Duration
// Multiplier is the backoff multiplier
Multiplier float64
// RandomizationFactor adds jitter to delays (0.0 to 1.0)
RandomizationFactor float64
// RetryableErrors defines which errors should trigger a retry
RetryableErrors []string
// RetryableStatusCodes defines which HTTP status codes should trigger a retry
RetryableErrors []string
RetryableStatusCodes []int
MaxAttempts int
InitialDelay time.Duration
MaxDelay time.Duration
Multiplier float64
RandomizationFactor float64
}
// DefaultRetryConfig returns sensible default retry configuration
@@ -46,13 +39,11 @@ func DefaultRetryConfig() RetryConfig {
// RetryExecutor implements retry logic with exponential backoff
type RetryExecutor struct {
lastRetryTime time.Time
*BaseRecoveryMechanism
config RetryConfig
// Metrics
config RetryConfig
totalRetries int64
maxRetriesHit int64
lastRetryTime time.Time
retryTimeMutex sync.RWMutex
}
@@ -125,7 +116,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
// Continue to next attempt
case <-ctx.Done():
re.RecordFailure()
return fmt.Errorf("retry cancelled: %w", ctx.Err())
return fmt.Errorf("retry canceled: %w", ctx.Err())
}
}
@@ -310,7 +301,7 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
}
}
allMetrics["summary"] = map[string]interface{}{
summary := map[string]interface{}{
"totalMechanisms": len(rm.mechanisms),
"totalRequests": totalRequests,
"totalSuccesses": totalSuccesses,
@@ -319,8 +310,9 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
if totalRequests > 0 {
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
summary["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
}
allMetrics["summary"] = summary
return allMetrics
}
+10 -10
View File
@@ -223,7 +223,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
wg.Wait()
if execErr == nil {
t.Error("Expected error when context is cancelled")
t.Error("Expected error when context is canceled")
}
}
@@ -240,7 +240,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing
})
if err == nil {
t.Error("Expected error when context is already cancelled")
t.Error("Expected error when context is already canceled")
}
}
@@ -273,17 +273,17 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
executor := NewRetryExecutor(config, logger)
tests := []struct {
name string
err error
name string
expected bool
}{
{"nil error", nil, false},
{"connection refused", errors.New("connection refused"), true},
{"timeout", errors.New("TIMEOUT"), true}, // case insensitive
{"EOF", errors.New("EOF"), false},
{"random error", errors.New("something else"), false},
{"context cancelled", context.Canceled, false},
{"context deadline exceeded", context.DeadlineExceeded, false},
{name: "nil error", err: nil, expected: false},
{name: "connection refused", err: errors.New("connection refused"), expected: true},
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
{name: "EOF", err: errors.New("EOF"), expected: false},
{name: "random error", err: errors.New("something else"), expected: false},
{name: "context canceled", err: context.Canceled, expected: false},
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
}
for _, tt := range tests {
+6 -6
View File
@@ -13,10 +13,10 @@ import (
// Mock logger for testing
type mockLogger struct {
mu sync.Mutex
logs []string
errLogs []string
debugLog []string
mu sync.Mutex
}
func (m *mockLogger) Logf(format string, args ...interface{}) {
@@ -202,13 +202,13 @@ func TestBaseRecoveryMechanism_ConcurrentAccess(t *testing.T) {
// CircuitBreakerState tests
func TestCircuitBreakerState_String(t *testing.T) {
tests := []struct {
state CircuitBreakerState
expected string
state CircuitBreakerState
}{
{CircuitBreakerClosed, "closed"},
{CircuitBreakerOpen, "open"},
{CircuitBreakerHalfOpen, "half-open"},
{CircuitBreakerState(99), "unknown"},
{state: CircuitBreakerClosed, expected: "closed"},
{state: CircuitBreakerOpen, expected: "open"},
{state: CircuitBreakerHalfOpen, expected: "half-open"},
{state: CircuitBreakerState(99), expected: "unknown"},
}
for _, tt := range tests {
+5 -5
View File
@@ -29,26 +29,26 @@ type JWK struct {
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
Scope string `json:"scope,omitempty"`
ExpiresIn int `json:"expires_in"`
}
// IntrospectionResponse represents a token introspection response
type IntrospectionResponse struct {
Active bool `json:"active"`
Scope string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
Username string `json:"username,omitempty"`
TokenType string `json:"token_type,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Sub string `json:"sub,omitempty"`
Aud string `json:"aud,omitempty"`
Iss string `json:"iss,omitempty"`
Jti string `json:"jti,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Active bool `json:"active"`
}
// JWKCache is a testify mock for JWK caching operations
+2 -2
View File
@@ -8,16 +8,16 @@ import (
// SessionData represents session data for testing
type SessionData struct {
Claims map[string]interface{}
Email string
AccessToken string
RefreshToken string
IDToken string
Expiry int64
Nonce string
State string
CodeVerifier string
RedirectURL string
Claims map[string]interface{}
Expiry int64
}
// SessionManager is a testify mock for session management
+24 -39
View File
@@ -16,45 +16,30 @@ import (
// OIDCServerConfig configures the mock OIDC server behavior
type OIDCServerConfig struct {
// Identity
Issuer string
// Discovery
ScopesSupported []string
ResponseTypesSupported []string
JWKSResponse map[string]interface{}
TokenFixture *fixtures.TokenFixture
UserinfoError *OIDCError
UserinfoResponse map[string]interface{}
IntrospectionResponse map[string]interface{}
JWKSError *OIDCError
RefreshError *OIDCError
TokenResponse map[string]interface{}
TokenError *OIDCError
IntrospectionError *OIDCError
RefreshResponse map[string]interface{}
Issuer string
GrantTypesSupported []string
ClaimsSupported []string
TokenEndpointAuthMethods []string
// Token fixture for signing
TokenFixture *fixtures.TokenFixture
// Token endpoint behavior
TokenResponse map[string]interface{}
TokenError *OIDCError
TokenDelay time.Duration
RefreshResponse map[string]interface{}
RefreshError *OIDCError
// JWKS behavior
JWKSResponse map[string]interface{}
JWKSError *OIDCError
JWKSDelay time.Duration
// Introspection behavior
IntrospectionResponse map[string]interface{}
IntrospectionError *OIDCError
// Userinfo behavior
UserinfoResponse map[string]interface{}
UserinfoError *OIDCError
// Simulation flags
SimulateTimeout bool
TimeoutDuration time.Duration
RateLimitAfter int
FailAfterN int
FailWithStatus int
ScopesSupported []string
ClaimsSupported []string
ResponseTypesSupported []string
FailAfterN int
JWKSDelay time.Duration
TimeoutDuration time.Duration
RateLimitAfter int
TokenDelay time.Duration
FailWithStatus int
SimulateTimeout bool
}
// OIDCError represents an OAuth error response
@@ -67,9 +52,9 @@ type OIDCError struct {
type OIDCServer struct {
*httptest.Server
Config *OIDCServerConfig
RequestCount int32
mu sync.Mutex
requests []*http.Request
mu sync.Mutex
RequestCount int32
}
// NewOIDCServer creates a new mock OIDC server
+4 -4
View File
@@ -135,9 +135,9 @@ func TestIsTestMode(t *testing.T) {
// We'll test what we can control via environment variables.
tests := []struct {
name string
setup func()
cleanup func()
name string
expected bool
}{
{
@@ -206,8 +206,8 @@ func TestIsTestMode(t *testing.T) {
func TestIsTestModeEdgeCases(t *testing.T) {
// Test with various environment variable combinations
tests := []struct {
name string
env map[string]string
name string
}{
{
name: "all env vars empty",
@@ -560,11 +560,11 @@ func TestIsTestModeYaegiCompiler(t *testing.T) {
// mockLogger is a simple mock implementation for testing
type mockLogger struct {
lastFormat string
lastArgs []interface{}
infoCalls int
debugCalls int
errorCalls int
lastFormat string
lastArgs []interface{}
}
func (m *mockLogger) Infof(format string, args ...interface{}) {
+12 -21
View File
@@ -21,25 +21,16 @@ import (
// JWK represents a JSON Web Key as defined in RFC 7517.
// It can represent different key types including RSA, EC, and symmetric keys.
type JWK struct {
// Key type (e.g., "RSA", "EC", "oct")
Kty string `json:"kty"`
// Key use (e.g., "sig" for signature, "enc" for encryption)
Use string `json:"use,omitempty"`
// Key operations allowed
Kty string `json:"kty"`
Use string `json:"use,omitempty"`
Alg string `json:"alg,omitempty"`
Kid string `json:"kid,omitempty"`
N string `json:"n,omitempty"`
E string `json:"e,omitempty"`
Crv string `json:"crv,omitempty"`
X string `json:"x,omitempty"`
Y string `json:"y,omitempty"`
KeyOps []string `json:"key_ops,omitempty"`
// Algorithm intended for use with this key
Alg string `json:"alg,omitempty"`
// Key ID
Kid string `json:"kid,omitempty"`
// RSA specific fields
N string `json:"n,omitempty"` // Modulus
E string `json:"e,omitempty"` // Exponent
// EC specific fields
Crv string `json:"crv,omitempty"` // Curve
X string `json:"x,omitempty"` // X coordinate
Y string `json:"y,omitempty"` // Y coordinate
}
// JWKSet represents a set of JSON Web Keys.
@@ -222,9 +213,9 @@ func (jwk *JWK) ToECDSAPublicKey() (*ecdsa.PublicKey, error) {
// GetKey finds a key by its ID (kid) in the JWKSet.
// Returns nil if no key with the given ID is found.
func (jwks *JWKSet) GetKey(kid string) *JWK {
for _, key := range jwks.Keys {
if key.Kid == kid {
return &key
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
return &jwks.Keys[i]
}
}
return nil
+9 -9
View File
@@ -309,18 +309,18 @@ func TestJWKCacheCleanupAndClose(t *testing.T) {
func TestFetchJWKSEdgeCases(t *testing.T) {
t.Run("handles various HTTP status codes", func(t *testing.T) {
testCases := []struct {
errContains string
status int
wantErr bool
errContains string
}{
{200, false, ""},
{400, true, "400"},
{401, true, "401"},
{403, true, "403"},
{404, true, "404"},
{500, true, "500"},
{502, true, "502"},
{503, true, "503"},
{status: 200, wantErr: false, errContains: ""},
{status: 400, wantErr: true, errContains: "400"},
{status: 401, wantErr: true, errContains: "401"},
{status: 403, wantErr: true, errContains: "403"},
{status: 404, wantErr: true, errContains: "404"},
{status: 500, wantErr: true, errContains: "500"},
{status: 502, wantErr: true, errContains: "502"},
{status: 503, wantErr: true, errContains: "503"},
}
for _, tc := range testCases {
+1 -1
View File
@@ -120,7 +120,7 @@ func getReplayCacheStats() (size int, maxSize int) {
// Parameters:
// - ctx: Parent context for cancellation
// - logger: Logger for debug output (can be nil)
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
registry := GetGlobalTaskRegistry()
// Define the cleanup task function
+502
View File
@@ -0,0 +1,502 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
const (
// logoutTokenType is the expected typ claim for logout tokens
// #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
logoutTokenType = "logout+jwt"
// sessionInvalidationTTL is how long to remember invalidated sessions
// Should be at least as long as your session max age
sessionInvalidationTTL = 25 * time.Hour
)
// LogoutTokenClaims represents the claims in an OIDC logout token
// as defined in OpenID Connect Back-Channel Logout 1.0
type LogoutTokenClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub,omitempty"`
Audience interface{} `json:"aud"` // Can be string or []string
IssuedAt int64 `json:"iat"`
JTI string `json:"jti"`
Events map[string]interface{} `json:"events"`
SessionID string `json:"sid,omitempty"`
Nonce string `json:"nonce,omitempty"` // Must NOT be present
}
// handleBackchannelLogout processes OIDC Backchannel Logout requests.
// It accepts POST requests with a logout_token parameter containing a JWT
// that identifies which session(s) to terminate.
//
// According to OpenID Connect Back-Channel Logout 1.0:
// - The logout_token is a JWT signed by the IdP
// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
// - The RP must validate the token and invalidate the matching session(s)
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing the logout_token
func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing backchannel logout request")
// Backchannel logout must be POST
if req.Method != http.MethodPost {
t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse form data to get logout_token
if err := req.ParseForm(); err != nil {
t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
http.Error(rw, "Bad request", http.StatusBadRequest)
return
}
logoutToken := req.FormValue("logout_token")
if logoutToken == "" {
// Also try reading from request body as raw JWT
body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
if err == nil && len(body) > 0 {
logoutToken = string(body)
}
}
if logoutToken == "" {
t.logger.Error("Backchannel logout: missing logout_token")
http.Error(rw, "logout_token required", http.StatusBadRequest)
return
}
// Parse and validate the logout token
claims, err := t.validateLogoutToken(logoutToken)
if err != nil {
t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
// Return 400 for invalid token per spec
http.Error(rw, "Invalid logout token", http.StatusBadRequest)
return
}
// Invalidate session(s) based on sid or sub
if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
claims.SessionID, claims.Subject)
// Return 200 OK with empty body per spec
rw.WriteHeader(http.StatusOK)
}
// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
// It accepts GET requests with 'iss' and 'sid' query parameters that identify
// which session to terminate. The IdP typically loads this URL in an iframe.
//
// According to OpenID Connect Front-Channel Logout 1.0:
// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
// - The RP should clear the session and return a response (typically empty or image)
// - The response must be cacheable to allow the IdP to load it in an iframe
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing iss and sid parameters
func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing front-channel logout request")
// Front-channel logout should be GET
if req.Method != http.MethodGet {
t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get iss and sid from query parameters
iss := req.URL.Query().Get("iss")
sid := req.URL.Query().Get("sid")
// Validate issuer matches our expected issuer
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
if iss != "" && iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return
}
// Must have at least sid for front-channel logout
if sid == "" {
t.logger.Error("Front-channel logout: missing sid parameter")
http.Error(rw, "sid parameter required", http.StatusBadRequest)
return
}
// Invalidate the session
if err := t.invalidateSession(sid, ""); err != nil {
t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
// Return a minimal HTML response that's suitable for iframe loading
// Set headers to allow embedding and caching
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Cache-Control", "no-cache, no-store")
rw.Header().Set("Pragma", "no-cache")
// Allow embedding in iframes from any origin (required for front-channel logout)
rw.Header().Del("X-Frame-Options")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("<!DOCTYPE html><html><head><title>Logged Out</title></head><body></body></html>"))
}
// validateLogoutToken parses and validates a logout token JWT.
// It verifies the token signature, issuer, audience, and required claims.
//
// Parameters:
// - tokenString: The raw JWT logout token
//
// Returns:
// - The parsed logout token claims
// - An error if validation fails
func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
// Parse the JWT
jwt, err := parseJWT(tokenString)
if err != nil {
return nil, fmt.Errorf("failed to parse logout token: %w", err)
}
// Check token type if present
if typ, ok := jwt.Header["typ"].(string); ok {
// The typ should be "logout+jwt" or omitted
if typ != "" && typ != logoutTokenType && typ != "JWT" {
return nil, fmt.Errorf("invalid token type: %s", typ)
}
}
// Verify signature only (not standard claims - logout tokens don't have 'exp')
if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
return nil, fmt.Errorf("signature verification failed: %w", err)
}
// Extract claims
claims := &LogoutTokenClaims{}
claimsJSON, err := json.Marshal(jwt.Claims)
if err != nil {
return nil, fmt.Errorf("failed to marshal claims: %w", err)
}
if err := json.Unmarshal(claimsJSON, claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
}
// Validate required claims
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
// Validate issuer
if claims.Issuer != expectedIssuer {
return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
}
// Validate audience (must contain our client_id)
if !t.validateLogoutTokenAudience(claims.Audience) {
return nil, fmt.Errorf("audience validation failed")
}
// Validate iat (issued at) - must be present and not too old
if claims.IssuedAt == 0 {
return nil, fmt.Errorf("missing iat claim")
}
iatTime := time.Unix(claims.IssuedAt, 0)
// Allow up to 5 minutes clock skew and 10 minutes token age
if time.Since(iatTime) > 15*time.Minute {
return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
}
// Token should not be from the future (with 5 min clock skew tolerance)
if iatTime.After(time.Now().Add(5 * time.Minute)) {
return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
}
// Validate events claim - must contain the logout event
if claims.Events == nil {
return nil, fmt.Errorf("missing events claim")
}
if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
return nil, fmt.Errorf("missing backchannel-logout event in events claim")
}
// Validate that nonce is NOT present (per spec)
if claims.Nonce != "" {
return nil, fmt.Errorf("nonce claim must not be present in logout token")
}
// Must have either sid or sub (or both)
if claims.SessionID == "" && claims.Subject == "" {
return nil, fmt.Errorf("logout token must contain either sid or sub claim")
}
return claims, nil
}
// validateLogoutTokenAudience checks if the logout token audience contains our client_id
func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
switch v := aud.(type) {
case string:
return v == t.clientID
case []interface{}:
for _, a := range v {
if s, ok := a.(string); ok && s == t.clientID {
return true
}
}
case []string:
for _, a := range v {
if a == t.clientID {
return true
}
}
}
return false
}
// verifyLogoutTokenSignature verifies only the signature of a logout token.
// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
//
// Parameters:
// - jwt: The parsed JWT structure
// - tokenString: The raw token string for signature verification
//
// Returns:
// - An error if signature verification fails
func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
t.logger.Debug("Verifying logout token signature")
// Read jwksURL with RLock
t.metadataMu.RLock()
jwksURL := t.jwksURL
t.metadataMu.RUnlock()
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
if jwks == nil {
return fmt.Errorf("JWKS is nil, cannot verify token")
}
kid, ok := jwt.Header["kid"].(string)
if !ok || kid == "" {
return fmt.Errorf("missing key ID in token header")
}
alg, ok := jwt.Header["alg"].(string)
if !ok || alg == "" {
return fmt.Errorf("missing algorithm in token header")
}
// Find the matching key in JWKS
var matchingKey *JWK
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
matchingKey = &jwks.Keys[i]
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Logout token signature verified successfully")
return nil
}
// invalidateSession marks a session as invalidated in the session invalidation cache.
// It stores entries by both sid and sub if available.
//
// Parameters:
// - sid: The session ID to invalidate (from the 'sid' claim)
// - sub: The subject to invalidate (from the 'sub' claim)
//
// Returns:
// - An error if the invalidation fails
func (t *TraefikOidc) invalidateSession(sid, sub string) error {
if t.sessionInvalidationCache == nil {
return fmt.Errorf("session invalidation cache not initialized")
}
now := time.Now().Unix()
// Store by session ID
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sid: %s", sid)
}
// Store by subject (invalidates all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sub: %s", sub)
}
return nil
}
// isSessionInvalidated checks if a session has been invalidated via backchannel
// or front-channel logout.
//
// Parameters:
// - sid: The session ID to check
// - sub: The subject to check
// - sessionCreatedAt: When the session was created (to compare against invalidation time)
//
// Returns:
// - true if the session has been invalidated, false otherwise
func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
if t.sessionInvalidationCache == nil {
return false
}
// Truncate session creation time to seconds for fair comparison with Unix timestamps
sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
// Check by session ID first (more specific)
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Session was invalidated at or after it was created
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sid: %s", sid)
return true
}
}
}
}
// Check by subject (all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Sessions for this subject created at or before invalidation are invalid
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sub: %s", sub)
return true
}
}
}
}
return false
}
// buildSessionInvalidationKey creates a cache key for session invalidation
func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
}
// extractSessionInfo extracts sid and sub from an ID token for session tracking
func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
if idToken == "" {
return "", "", time.Time{}
}
jwt, err := parseJWT(idToken)
if err != nil {
return "", "", time.Time{}
}
// Extract sid (session ID)
if sidVal, ok := jwt.Claims["sid"].(string); ok {
sid = sidVal
}
// Extract sub (subject)
if subVal, ok := jwt.Claims["sub"].(string); ok {
sub = subVal
}
// Extract iat for session creation time
if iatVal, ok := jwt.Claims["iat"].(float64); ok {
createdAt = time.Unix(int64(iatVal), 0)
} else {
// Default to now if iat not present
createdAt = time.Now()
}
return sid, sub, createdAt
}
// determineLogoutPath checks if the given path matches any logout URL
func (t *TraefikOidc) determineLogoutPath(path string) string {
// Check backchannel logout path
if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
return "backchannel"
}
// Check front-channel logout path
if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
return "frontchannel"
}
// Check regular logout path (for RP-initiated logout)
if path == t.logoutURLPath {
return "rp"
}
return ""
}
// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
func normalizeLogoutPath(path string) string {
if path == "" {
return ""
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// Prevent open redirect: ensure second character is not / or \
// This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
// Strip leading slashes/backslashes and re-normalize
path = strings.TrimLeft(path, "/\\")
if path != "" {
path = "/" + path
}
}
return path
}
+1623
View File
File diff suppressed because it is too large Load Diff
+35 -10
View File
@@ -212,16 +212,22 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return 60 * time.Second
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
stripAuthCookies: config.StripAuthCookies,
enableBackchannelLogout: config.EnableBackchannelLogout,
enableFrontchannelLogout: config.EnableFrontchannelLogout,
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
}
// Log audience configuration
@@ -298,6 +304,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
// Log callback URL configuration to help diagnose redirect loop issues.
// If callbackURL is a full URL instead of a path, the callback matching
// in ServeHTTP will silently fail because req.URL.Path is compared directly.
logger.Debugf("TraefikOidc.New: callbackURL (redirURLPath) configured as: %q", t.redirURLPath)
logger.Debugf("TraefikOidc.New: logoutURLPath configured as: %q", t.logoutURLPath)
t.providerURL = config.ProviderURL
// Use singleton resource manager for metadata initialization
@@ -433,6 +445,19 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
t.dcrConfig,
t.providerURL,
)
// Set up storage backend for credentials persistence
if t.dcrConfig.PersistCredentials {
cacheManager := GetGlobalCacheManagerWithConfig(t.goroutineWG, nil)
store, err := NewDCRCredentialsStore(t.dcrConfig, cacheManager, t.logger)
if err != nil {
t.logger.Errorf("Failed to create DCR credentials store: %v", err)
// Continue without persistence - registration will still work
} else {
t.dynamicClientRegistrar.SetStore(store)
t.logger.Debugf("DCR credentials store initialized with backend: %s", t.dcrConfig.StorageBackend)
}
}
}
// Get registration endpoint (from metadata or config override)
+1 -1
View File
@@ -9,7 +9,7 @@ import (
"gopkg.in/yaml.v3"
)
// Config Marshalling Tests
// Config Marshaling Tests
func TestConfig_MarshalJSON(t *testing.T) {
config := &Config{
+3 -3
View File
@@ -18,15 +18,15 @@ import (
// TestExchangeCodeForToken_Comprehensive tests the ExchangeCodeForToken function comprehensively
func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
tests := []struct {
setupMock func(*httptest.Server) *TraefikOidc
validateFunc func(*testing.T, *TokenResponse, error)
name string
grantType string
code string
redirectURL string
codeVerifier string
setupMock func(*httptest.Server) *TraefikOidc
validateFunc func(*testing.T, *TokenResponse, error)
wantErr bool
expectedError string
wantErr bool
}{
{
name: "successful authorization code exchange",
+2 -2
View File
@@ -13,9 +13,9 @@ import (
func TestGoroutineLeakPrevention_ContextCancellation(t *testing.T) {
tests := []struct {
name string
cancelAfter time.Duration
expectedLeaks int // Maximum expected goroutines after cleanup
description string
cancelAfter time.Duration
expectedLeaks int
}{
{
name: "immediate_cancellation",
+2 -2
View File
@@ -15,10 +15,10 @@ import (
// TestInitializeMetadata tests the initializeMetadata function
func TestInitializeMetadata(t *testing.T) {
tests := []struct {
name string
providerURL string
setupMock func() *httptest.Server
validateFunc func(*testing.T, *TraefikOidc)
name string
providerURL string
wantPanic bool
}{
{
+3 -3
View File
@@ -16,12 +16,12 @@ import (
// TestGetNewTokenWithRefreshToken tests the GetNewTokenWithRefreshToken function
func TestGetNewTokenWithRefreshToken(t *testing.T) {
tests := []struct {
name string
refreshToken string
setupMock func(*httptest.Server) *TraefikOidc
validateFunc func(*testing.T, *TokenResponse, error)
wantErr bool
name string
refreshToken string
expectedError string
wantErr bool
}{
{
name: "successful token refresh",
+345 -4
View File
@@ -10,9 +10,9 @@ import (
// TestServeHTTP_ExcludedURLs tests the excluded URLs functionality
func TestServeHTTP_ExcludedURLs(t *testing.T) {
tests := []struct {
excludedURLs map[string]struct{}
name string
path string
excludedURLs map[string]struct{}
shouldBypass bool
}{
{
@@ -506,12 +506,12 @@ type MockSessionData struct {
idToken string
accessToken string
refreshToken string
authenticated bool
isDirty bool
redirectCount int
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
}
func (m *MockSessionData) GetEmail() string { return m.email }
@@ -710,3 +710,344 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
}
}
// TestStripAuthCookies tests the stripAuthCookies configuration option.
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
func TestStripAuthCookies(t *testing.T) {
tests := []struct {
name string
stripAuthCookies bool
expectOIDCCookies bool
expectAppCookies bool
}{
{
name: "stripAuthCookies=false (default) forwards all cookies",
stripAuthCookies: false,
expectOIDCCookies: true,
expectAppCookies: true,
},
{
name: "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
stripAuthCookies: true,
expectOIDCCookies: false,
expectAppCookies: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: tt.stripAuthCookies,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
// Get a valid session first (before adding fake cookies)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
// Now add OIDC session cookies (simulating what the browser would send)
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
// Add non-OIDC application cookies (these must always pass through)
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Check for OIDC cookies in captured cookies
hasOIDCCookie := false
hasAppSession := false
hasTheme := false
for _, c := range capturedCookies {
if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
hasOIDCCookie = true
}
if c.Name == "my_app_session" {
hasAppSession = true
}
if c.Name == "theme" {
hasTheme = true
}
}
if tt.expectOIDCCookies && !hasOIDCCookie {
t.Error("expected OIDC cookies to be forwarded to backend")
}
if !tt.expectOIDCCookies && hasOIDCCookie {
t.Error("expected OIDC cookies to be stripped before forwarding to backend")
}
if tt.expectAppCookies && !hasAppSession {
t.Error("expected my_app_session cookie to be forwarded to backend")
}
if tt.expectAppCookies && !hasTheme {
t.Error("expected theme cookie to be forwarded to backend")
}
})
}
}
// TestStripAuthCookies_NoCookies verifies stripping works when the request has no cookies.
func TestStripAuthCookies_NoCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected no cookies, got %d", len(capturedCookies))
}
}
// TestStripAuthCookies_OnlyOIDCCookies verifies that when all cookies are OIDC cookies,
// the Cookie header is empty after stripping.
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
var capturedCookieHeader string
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookieHeader = r.Header.Get("Cookie")
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
// Add only OIDC cookies
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
}
if capturedCookieHeader != "" {
t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
}
}
// TestStripAuthCookies_OnlyAppCookies verifies that non-OIDC cookies pass through
// untouched when stripping is enabled.
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
// Add only non-OIDC cookies
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "abc123"})
req.AddCookie(&http.Cookie{Name: "lang", Value: "en"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 3 {
t.Errorf("expected 3 cookies, got %d", len(capturedCookies))
}
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
for _, expected := range []string{"my_app_session", "lang", "theme"} {
if !cookieNames[expected] {
t.Errorf("expected cookie %q to be forwarded", expected)
}
}
}
// TestStripAuthCookies_CustomPrefix verifies stripping works with a custom cookie prefix.
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
// Create session manager with custom prefix
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
customPrefix := sm.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
// Add cookies with the custom prefix (should be stripped)
req.AddCookie(&http.Cookie{Name: customPrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: customPrefix + "s_0", Value: "chunk0"})
// Add default-prefix cookie (should NOT be stripped — different prefix)
req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
// Add app cookie (should NOT be stripped)
req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
// Custom prefix cookies should be stripped
if cookieNames[customPrefix+"m"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"m")
}
if cookieNames[customPrefix+"s_0"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"s_0")
}
// Default prefix cookie should pass through (different prefix)
if !cookieNames["_oidc_raczylo_m"] {
t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
}
// App cookie should pass through
if !cookieNames["my_app"] {
t.Error("expected my_app cookie to pass through")
}
}
+2 -2
View File
@@ -81,11 +81,11 @@ func TestIsTestMode_DefaultBehavior(t *testing.T) {
// TestVerifyAudience tests the verifyAudience function
func TestVerifyAudience(t *testing.T) {
tests := []struct {
name string
tokenAudience interface{}
name string
expectedAudience string
expectError bool
description string
expectError bool
}{
{
name: "Audience matches",

Some files were not shown because too many files have changed in this diff Show More