Compare commits

...

5 Commits

Author SHA1 Message Date
lukaszraczylo 9d52f1b018 feat(core): refactor linters config and improve code quality (#119)
- [x] Reorganize golangci-lint configuration with documented disable reasons
- [x] Simplify errcheck and revive linter rules with targeted exclusions
- [x] Pre-compile regex patterns in input_validation.go for performance
- [x] Fix type assertions in memory_shard.go and resp.go with safety checks
- [x] Replace string comparison with EqualFold for case-insensitive matching
- [x] Fix loop variable captures in jwk.go and logout.go
- [x] Change high goroutine log level from Info to Debug in autocleanup.go
- [x] Replace deprecated "cancelled" spelling with "canceled" throughout
- [x] Add nolint annotations for intentional unused parameters
- [x] Improve comment formatting for deprecated functions
- [x] Fix comment spelling: "marshalling" → "marshaling"
- [x] Refactor provider warnings formatting in internal/providers/warnings.go
- [x] Simplify metrics summary building in internal/recovery/metrics.go
- [x] Pre-allocate slice in error_recovery.go GetDegradedServices
- [x] Refactor context cancellation checks in redis.go
2026-01-15 10:40:49 +00:00
lukaszraczylo 57724918fe fix 116 (#118)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases

* docs: add security fix documentation for integer overflow protection

* test: fix goroutine tests to use mock OIDC servers

The TestContextAwareGoroutineManagement tests were making real HTTP
calls to hardcoded URLs like https://example.com, causing failures
in CI when those requests timeout or return HTTP errors.

Changes:
- Added createMockOIDCServer() helper function using httptest
- Updated GoroutineCleanupOnContextCancel to use mock server
- Updated NoGoroutineLeakOnMultipleInstances to use 3 mock servers
- Updated SingletonTasksAcrossInstances to use mock servers array

This prevents network calls and makes tests more reliable and faster.

Fixes test failures in GitHub Actions CI.
2026-01-08 22:50:46 +00:00
lukaszraczylo 775de2ada1 Fix cache serialisation (#117)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases
2026-01-08 22:06:19 +00:00
lukaszraczylo 7816e05c98 fix issue with logout url (#112)
* fix(logout): handle logout requests before OIDC initialization

- [x] Add debug logging to logout handler entry point
- [x] Move logout path check before OIDC initialization to enable logout when provider unavailable
- [x] Move excluded URL and SSE checks before initialization wait
- [x] Add debug logging for initialization wait to diagnose hanging requests
- [x] Add test for logout functionality without OIDC provider availability

* feat(logout): implement OIDC backchannel and front-channel logout

- [x] Add logout token validation and backchannel logout handler
- [x] Add front-channel logout handler with iframe support
- [x] Implement session invalidation cache for distributed deployments
- [x] Add comprehensive logout token claim verification (issuer, audience, events, iat, sid/sub)
- [x] Integrate session invalidation checks into authorization flow
- [x] Add configuration options for enabling backchannel/front-channel logout
- [x] Add extensive test coverage for logout flows and edge cases
- [x] Update documentation with logout configuration examples
- [x] Add middleware routing for logout endpoints
- [x] Extend cache manager with session invalidation cache support

Resolves #110

* fixup! feat(logout): implement OIDC backchannel and front-channel logout

* fixup! Merge branch 'main' into fix-issue-with-logout-url
2026-01-04 01:59:50 +00:00
Dominik Chilla 8bf7998150 Fix for Hashicorp Vault - accept opaque access tokens with dot-characters (#113) 2026-01-02 16:42:22 +00:00
44 changed files with 3310 additions and 333 deletions
+1
View File
@@ -1,3 +1,4 @@
docker/
.claude/*.out
*.test
.leann/
+49 -32
View File
@@ -14,21 +14,22 @@ linters:
- gosec
- misspell
- noctx
- nolintlint
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- whitespace
disable:
- exhaustive
- funlen
- gocognit
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
- goprintffuncname # Disabled: naming convention is project-specific
- lll
- mnd
- testpackage
- whitespace # Disabled: style preference about newlines
- wsl
settings:
dupl:
@@ -47,29 +48,13 @@ linters:
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
ignore-tests: true
gocritic:
# Using default enabled checks in v2
enabled-checks:
- appendCombine
- boolExprSimplify
- builtinShadow
- commentedOutCode
- emptyFallthrough
- equalFold
- hexLiteral
- indexAlloc
- initClause
- methodExprCall
- nestingReduce
- rangeExprCopy
- rangeValCopy
- stringXbytes
- typeAssertChain
- typeUnparen
- unlabelStmt
- yodaStyleExpr
# Disable style-only checks that add noise
disabled-checks:
- ifElseChain # Style preference, switch not always clearer
- elseif # Style preference
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
@@ -106,23 +91,23 @@ linters:
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
- name: if-return
# - name: exported # Disabled: too noisy, not all exported functions need comments
# - name: if-return # Disabled: style preference
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
# - name: var-declaration # Disabled: explicit zero values can be clearer
# - name: package-comments # Disabled: handled by other tools
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
# - name: indent-error-flow # Disabled: style preference
- name: errorf
- name: empty-block
# - name: empty-block # Disabled: sometimes empty blocks are intentional
- name: superfluous-else
- name: unused-parameter
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
- name: unreachable-code
- name: redefines-builtin-id
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
unparam:
check-exported: false
staticcheck:
@@ -132,8 +117,15 @@ linters:
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1011 # Omit type from declaration - style preference
- -QF1012 # Use fmt.Fprintf - style preference
- -SA9003 # Empty branch - sometimes intentional for future work
- -ST1000 # Package comment format - not required for all packages
- -ST1003 # Package name format - allowed for test packages
- -ST1016 # Receiver name consistency - legacy code
- -ST1020 # Comment format for methods - style preference
- -ST1021 # Comment format for types - style preference
- -ST1023 # Omit type from declaration - style preference
exclusions:
generated: lax
rules:
@@ -144,18 +136,43 @@ linters:
- goconst
- gocyclo
- gosec
- govet
- ineffassign
- noctx
- prealloc
- unparam
- revive
- gocritic
path: _test\.go
- linters:
- dupl
- gocyclo
- govet
- noctx
- prealloc
- unparam
- revive
- gocritic
path: test.*\.go
- linters:
- gocritic
- unused
- errcheck
- revive
path: mocks.*\.go
- linters:
- errcheck
- revive
- gocritic
- govet
- unparam
path: internal/testutil/
- linters:
- govet
- unparam
- noctx
- prealloc
path: integration/
- linters:
- gosec
text: 'G404:'
+73
View File
@@ -1021,6 +1021,79 @@ configuration:
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
required: false
enableBackchannelLogout:
type: boolean
description: |
Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST).
When enabled, the middleware accepts logout tokens at the configured backchannelLogoutURL.
The IdP sends a signed JWT (logout_token) to notify the application that a user's session
should be terminated.
This implements the OIDC Back-Channel Logout 1.0 specification.
See: https://openid.net/specs/openid-connect-backchannel-1_0.html
Requirements:
- backchannelLogoutURL must be configured
- The IdP must be configured to send logout tokens to your backchannel URL
- Logout tokens are validated using the IdP's JWKS
Default: false
required: false
backchannelLogoutURL:
type: string
description: |
Path for receiving backchannel logout tokens from the IdP.
This endpoint receives POST requests with a logout_token JWT in the request body.
The token is validated against the IdP's JWKS and contains the session ID (sid)
and/or subject (sub) to invalidate.
Example: /backchannel-logout
The full URL to configure in your IdP would be:
https://your-app.example.com/backchannel-logout
Note: This path should be unique and not conflict with your application routes.
required: false
enableFrontchannelLogout:
type: boolean
description: |
Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe).
When enabled, the middleware accepts logout requests at the configured frontchannelLogoutURL.
The IdP embeds an iframe pointing to this URL when the user logs out, allowing the
application to clear the user's session.
This implements the OIDC Front-Channel Logout 1.0 specification.
See: https://openid.net/specs/openid-connect-frontchannel-1_0.html
Requirements:
- frontchannelLogoutURL must be configured
- The IdP must be configured with your front-channel logout URL
- Your CSP headers must allow being embedded in an iframe from the IdP
Default: false
required: false
frontchannelLogoutURL:
type: string
description: |
Path for receiving front-channel logout requests from the IdP.
This endpoint receives GET requests with optional sid (session ID) and iss (issuer)
query parameters. When called, it invalidates the user's session.
Example: /frontchannel-logout
The full URL to configure in your IdP would be:
https://your-app.example.com/frontchannel-logout
Note: This path should be unique and not conflict with your application routes.
required: false
headers:
type: array
description: |
+48
View File
@@ -154,6 +154,10 @@ The middleware supports the following configuration options:
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
| `enableBackchannelLogout` | Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST) | `false` | `true` |
| `backchannelLogoutURL` | The path for receiving backchannel logout tokens from the IdP | none | `/backchannel-logout` |
| `enableFrontchannelLogout` | Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe) | `false` | `true` |
| `frontchannelLogoutURL` | The path for receiving front-channel logout requests from the IdP | none | `/frontchannel-logout` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
@@ -1148,6 +1152,50 @@ spec:
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
### With IdP-Initiated Logout (Backchannel & Front-Channel)
This plugin supports [OIDC Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) and [OIDC Front-Channel Logout](https://openid.net/specs/openid-connect-frontchannel-1_0.html) for IdP-initiated single logout.
**Backchannel Logout** (recommended): The IdP sends a server-to-server POST request with a signed `logout_token` JWT when a user logs out.
**Front-Channel Logout**: The IdP loads an iframe with the logout URL to invalidate the session in the browser.
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-idp-logout
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout # RP-initiated logout
# Backchannel Logout (server-to-server)
enableBackchannelLogout: true
backchannelLogoutURL: /backchannel-logout
# Front-Channel Logout (iframe-based)
enableFrontchannelLogout: true
frontchannelLogoutURL: /frontchannel-logout
# For multi-replica deployments, use Redis to share session invalidations
redis:
enabled: true
address: redis:6379
```
> **Note**: For multi-replica deployments, you **must** enable Redis to share session invalidation state across all instances. Otherwise, a logout on one instance won't invalidate sessions on other instances.
**IdP Configuration**: Configure your IdP to send logout requests to:
- **Backchannel**: `https://your-app.example.com/backchannel-logout` (POST with `logout_token`)
- **Front-Channel**: `https://your-app.example.com/frontchannel-logout?sid=SESSION_ID&iss=ISSUER` (GET in iframe)
### With Templated Headers
```yaml
+49
View File
@@ -0,0 +1,49 @@
# Security Fix: Integer Overflow Protection in Cache Serialization
## Summary
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
## Vulnerability
**Locations**: `universal_cache.go` lines 789 and 811
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
## Fix Applied
1. **Added size limit constant**:
```go
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
```
2. **Size validation before allocation**:
- Validates entry size doesn't exceed limit
- Validates adding marker byte won't overflow
- Returns descriptive error messages
3. **Comprehensive test coverage**:
- Oversized byte slices (>64 MiB)
- Exact max size edge case
- Safe sizes (normal operation)
- Large JSON data structures
## Verification
✅ All tests pass with race detection
✅ No security issues (golangci-lint, gosec)
✅ 76.3% test coverage maintained
## Impact
- No breaking changes
- Negligible performance overhead
- Prevents potential buffer overflows
- Predictable memory usage
---
**Date**: January 8, 2026
**Severity**: High → Resolved
+4 -3
View File
@@ -599,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
return globalTaskMemoryMonitor
}
// NewTaskMemoryMonitor creates a new memory monitor for task registry
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
//
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
return GetGlobalTaskMemoryMonitor(logger)
}
@@ -712,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
// Check for goroutine leaks (arbitrary threshold)
if stats.Goroutines > 100 {
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
}
// Check for heap growth without corresponding GC activity
+11 -2
View File
@@ -20,8 +20,9 @@ var (
cacheManagerInitOnce sync.Once
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
// GetGlobalCacheManager returns a singleton CacheManager instance.
//
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
@@ -104,6 +105,14 @@ func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
// for backchannel and front-channel logout (IdP-initiated logout)
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
}
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
+2 -2
View File
@@ -7,7 +7,7 @@ import (
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
@@ -47,7 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
+67
View File
@@ -90,6 +90,7 @@
<a href="#configuration" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Configuration</a>
<a href="#deployment" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Deployment</a>
<a href="#security" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Security</a>
<a href="#logout" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Logout</a>
</div>
<div class="flex items-center space-x-4">
<button id="theme-toggle" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 p-2 min-w-[44px] min-h-[44px] flex items-center justify-center" aria-label="Toggle theme">
@@ -114,6 +115,7 @@
<a href="#configuration" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Configuration</a>
<a href="#deployment" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Deployment</a>
<a href="#security" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Security</a>
<a href="#logout" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Logout</a>
</div>
</div>
</nav>
@@ -1219,6 +1221,71 @@ spec:
</div>
</section>
<!-- IdP-Initiated Logout Section -->
<section id="logout" class="py-12 sm:py-16 md:py-20 bg-white dark:bg-gray-900 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
<div class="text-center mb-8 sm:mb-12">
<h2 class="text-2xl sm:text-3xl md:text-4xl font-bold text-gray-900 dark:text-gray-100 mb-3 sm:mb-4">IdP-Initiated Logout</h2>
<p class="text-base sm:text-lg text-gray-600 dark:text-gray-300 px-4">Support for OIDC Back-Channel and Front-Channel Logout specifications</p>
</div>
<div class="max-w-4xl mx-auto">
<div class="grid md:grid-cols-2 gap-6 mb-8">
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-server mr-2 text-blue-500"></i>
Back-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Signed JWT logout tokens</li>
<li>&#8226; Session ID (sid) based invalidation</li>
<li>&#8226; Subject (sub) based invalidation</li>
<li>&#8226; Works behind firewalls</li>
</ul>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-browser mr-2 text-purple-500"></i>
Front-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Iframe-based session termination</li>
<li>&#8226; Immediate cookie invalidation</li>
<li>&#8226; Simple GET request handling</li>
<li>&#8226; Issuer validation</li>
</ul>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Configuration Example</h3>
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
# ... other OIDC configuration ...
# Back-Channel Logout (server-to-server)
enableBackchannelLogout: true
backchannelLogoutURL: "/backchannel-logout"
# Front-Channel Logout (browser-based)
enableFrontchannelLogout: true
frontchannelLogoutURL: "/frontchannel-logout"</code></pre>
<p class="text-gray-600 dark:text-gray-400 text-sm mt-4">
Configure your IdP with the full URLs (e.g., <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">https://your-app.example.com/backchannel-logout</code>).
When a user logs out from the IdP, all their sessions across your applications will be invalidated.
</p>
</div>
</div>
</div>
</section>
<!-- Why Choose Section -->
<section class="py-12 sm:py-16 md:py-20 bg-gray-50 dark:bg-gray-800 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
+1 -1
View File
@@ -954,7 +954,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
var degraded []string
degraded := make([]string, 0, len(gd.degradedServices))
for serviceName := range gd.degradedServices {
degraded = append(degraded, serviceName)
}
+1
View File
@@ -336,6 +336,7 @@ func createStringMap(keys []string) map[string]struct{} {
// and redirects to the provider's logout endpoint or configured post-logout URI.
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing logout request")
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
+14 -27
View File
@@ -10,6 +10,14 @@ import (
"unicode/utf8"
)
// Pre-compiled regex patterns for validation (const patterns should use MustCompile)
var (
emailRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
urlRegexPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
tokenRegexPattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
usernameRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
)
// InputValidator provides comprehensive input validation and sanitization
// to protect against common security vulnerabilities including SQL injection,
// XSS, path traversal, and other injection attacks. It validates and sanitizes
@@ -73,7 +81,7 @@ func DefaultInputValidationConfig() InputValidationConfig {
}
// NewInputValidator creates a new input validator with the specified configuration.
// It compiles all necessary regex patterns and initializes security pattern lists.
// It uses pre-compiled regex patterns and initializes security pattern lists.
//
// Parameters:
// - config: Validation configuration with size limits and mode settings.
@@ -81,29 +89,8 @@ func DefaultInputValidationConfig() InputValidationConfig {
//
// Returns:
// - A configured InputValidator instance.
// - An error if regex compilation fails.
// - An error (always nil, kept for API compatibility).
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
// Compile regex patterns
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if err != nil {
return nil, fmt.Errorf("failed to compile email regex: %w", err)
}
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
if err != nil {
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
}
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile token regex: %w", err)
}
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile username regex: %w", err)
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
@@ -112,10 +99,10 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
emailRegex: emailRegexPattern,
urlRegex: urlRegexPattern,
tokenRegex: tokenRegexPattern,
usernameRegex: usernameRegexPattern,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
+9 -5
View File
@@ -241,9 +241,11 @@ func (s *cacheShard) evictLRULocked() bool {
element := s.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
s.deleteItemLocked(item)
return true
item, ok := element.Value.(*memoryCacheItem)
if ok {
s.deleteItemLocked(item)
return true
}
}
return false
}
@@ -267,8 +269,10 @@ func (s *cacheShard) getOldestAccessTime() time.Time {
element := s.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
return item.accessedAt
item, ok := element.Value.(*memoryCacheItem)
if ok {
return item.accessedAt
}
}
return time.Time{}
}
+2 -2
View File
@@ -345,7 +345,7 @@ func (r *RedisBackend) prefixKey(key string) string {
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
// It checks context cancellation at multiple points to ensure fast abort when the
// caller's context is cancelled (e.g., due to request timeout).
// caller's context is canceled (e.g., due to request timeout).
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
@@ -377,7 +377,7 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
err = operation(conn)
r.pool.Put(conn)
// Check context after operation - if cancelled, don't bother retrying
// Check context after operation - if canceled, don't bother retrying
if ctx.Err() != nil {
return ctx.Err()
}
+1 -1
View File
@@ -201,7 +201,7 @@ func TestConnectionPool_ContextCancellation(t *testing.T) {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
// Try to get another with canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
+2 -2
View File
@@ -44,7 +44,7 @@ type RESPWriter struct {
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer, _ := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
}
@@ -80,7 +80,7 @@ type RESPReader struct {
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader, _ := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
}
+1 -1
View File
@@ -87,7 +87,7 @@ func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher
// If successful, store in cache
if call.err == nil && call.val != nil {
// Use a background context for cache storage to ensure it completes
// even if the original context is cancelled
// even if the original context is canceled
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
cancel()
+1 -1
View File
@@ -232,7 +232,7 @@ func (m *Manager) Close() error {
var firstErr error
if err := m.tokenCache.Close(); err != nil && firstErr == nil {
if err := m.tokenCache.Close(); err != nil {
firstErr = err
}
if err := m.metadataCache.Close(); err != nil && firstErr == nil {
+1 -1
View File
@@ -397,7 +397,7 @@ func (wp *WorkerPool) Submit(task func()) error {
}
// worker is the main worker routine
func (wp *WorkerPool) worker(id int) {
func (wp *WorkerPool) worker(_ int) {
defer wp.workerWg.Done()
for {
+1 -1
View File
@@ -173,7 +173,7 @@ func (m *FeatureManager) LoadFromEnv() {
for name, flag := range flags {
envVar := "FEATURE_" + name
if value := os.Getenv(envVar); value != "" {
enabled := strings.ToLower(value) == "true" || value == "1"
enabled := strings.EqualFold(value, "true") || value == "1"
flag.enabled.Store(enabled)
}
}
+1 -1
View File
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
var filteredScopes []string
for _, scope := range scopes {
if strings.ToLower(scope) != ScopeOfflineAccess {
if !strings.EqualFold(scope, ScopeOfflineAccess) {
filteredScopes = append(filteredScopes, scope)
}
}
+11 -10
View File
@@ -18,16 +18,17 @@ func GetProviderWarnings(providerType ProviderType) []ProviderWarning {
switch providerType {
case ProviderTypeGitHub:
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
})
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
warnings = append(warnings,
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
},
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
case ProviderTypeAuth0:
warnings = append(warnings, ProviderWarning{
+4 -3
View File
@@ -116,7 +116,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
// Continue to next attempt
case <-ctx.Done():
re.RecordFailure()
return fmt.Errorf("retry cancelled: %w", ctx.Err())
return fmt.Errorf("retry canceled: %w", ctx.Err())
}
}
@@ -301,7 +301,7 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
}
}
allMetrics["summary"] = map[string]interface{}{
summary := map[string]interface{}{
"totalMechanisms": len(rm.mechanisms),
"totalRequests": totalRequests,
"totalSuccesses": totalSuccesses,
@@ -310,8 +310,9 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
if totalRequests > 0 {
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
summary["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
}
allMetrics["summary"] = summary
return allMetrics
}
+3 -3
View File
@@ -223,7 +223,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
wg.Wait()
if execErr == nil {
t.Error("Expected error when context is cancelled")
t.Error("Expected error when context is canceled")
}
}
@@ -240,7 +240,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing
})
if err == nil {
t.Error("Expected error when context is already cancelled")
t.Error("Expected error when context is already canceled")
}
}
@@ -282,7 +282,7 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
{name: "EOF", err: errors.New("EOF"), expected: false},
{name: "random error", err: errors.New("something else"), expected: false},
{name: "context cancelled", err: context.Canceled, expected: false},
{name: "context canceled", err: context.Canceled, expected: false},
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
}
+3 -3
View File
@@ -213,9 +213,9 @@ func (jwk *JWK) ToECDSAPublicKey() (*ecdsa.PublicKey, error) {
// GetKey finds a key by its ID (kid) in the JWKSet.
// Returns nil if no key with the given ID is found.
func (jwks *JWKSet) GetKey(kid string) *JWK {
for _, key := range jwks.Keys {
if key.Kid == kid {
return &key
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
return &jwks.Keys[i]
}
}
return nil
+1 -1
View File
@@ -120,7 +120,7 @@ func getReplayCacheStats() (size int, maxSize int) {
// Parameters:
// - ctx: Parent context for cancellation
// - logger: Logger for debug output (can be nil)
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
registry := GetGlobalTaskRegistry()
// Define the cleanup task function
+502
View File
@@ -0,0 +1,502 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
const (
// logoutTokenType is the expected typ claim for logout tokens
// #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
logoutTokenType = "logout+jwt"
// sessionInvalidationTTL is how long to remember invalidated sessions
// Should be at least as long as your session max age
sessionInvalidationTTL = 25 * time.Hour
)
// LogoutTokenClaims represents the claims in an OIDC logout token
// as defined in OpenID Connect Back-Channel Logout 1.0
type LogoutTokenClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub,omitempty"`
Audience interface{} `json:"aud"` // Can be string or []string
IssuedAt int64 `json:"iat"`
JTI string `json:"jti"`
Events map[string]interface{} `json:"events"`
SessionID string `json:"sid,omitempty"`
Nonce string `json:"nonce,omitempty"` // Must NOT be present
}
// handleBackchannelLogout processes OIDC Backchannel Logout requests.
// It accepts POST requests with a logout_token parameter containing a JWT
// that identifies which session(s) to terminate.
//
// According to OpenID Connect Back-Channel Logout 1.0:
// - The logout_token is a JWT signed by the IdP
// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
// - The RP must validate the token and invalidate the matching session(s)
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing the logout_token
func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing backchannel logout request")
// Backchannel logout must be POST
if req.Method != http.MethodPost {
t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse form data to get logout_token
if err := req.ParseForm(); err != nil {
t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
http.Error(rw, "Bad request", http.StatusBadRequest)
return
}
logoutToken := req.FormValue("logout_token")
if logoutToken == "" {
// Also try reading from request body as raw JWT
body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
if err == nil && len(body) > 0 {
logoutToken = string(body)
}
}
if logoutToken == "" {
t.logger.Error("Backchannel logout: missing logout_token")
http.Error(rw, "logout_token required", http.StatusBadRequest)
return
}
// Parse and validate the logout token
claims, err := t.validateLogoutToken(logoutToken)
if err != nil {
t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
// Return 400 for invalid token per spec
http.Error(rw, "Invalid logout token", http.StatusBadRequest)
return
}
// Invalidate session(s) based on sid or sub
if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
claims.SessionID, claims.Subject)
// Return 200 OK with empty body per spec
rw.WriteHeader(http.StatusOK)
}
// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
// It accepts GET requests with 'iss' and 'sid' query parameters that identify
// which session to terminate. The IdP typically loads this URL in an iframe.
//
// According to OpenID Connect Front-Channel Logout 1.0:
// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
// - The RP should clear the session and return a response (typically empty or image)
// - The response must be cacheable to allow the IdP to load it in an iframe
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing iss and sid parameters
func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing front-channel logout request")
// Front-channel logout should be GET
if req.Method != http.MethodGet {
t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get iss and sid from query parameters
iss := req.URL.Query().Get("iss")
sid := req.URL.Query().Get("sid")
// Validate issuer matches our expected issuer
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
if iss != "" && iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return
}
// Must have at least sid for front-channel logout
if sid == "" {
t.logger.Error("Front-channel logout: missing sid parameter")
http.Error(rw, "sid parameter required", http.StatusBadRequest)
return
}
// Invalidate the session
if err := t.invalidateSession(sid, ""); err != nil {
t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
// Return a minimal HTML response that's suitable for iframe loading
// Set headers to allow embedding and caching
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Cache-Control", "no-cache, no-store")
rw.Header().Set("Pragma", "no-cache")
// Allow embedding in iframes from any origin (required for front-channel logout)
rw.Header().Del("X-Frame-Options")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("<!DOCTYPE html><html><head><title>Logged Out</title></head><body></body></html>"))
}
// validateLogoutToken parses and validates a logout token JWT.
// It verifies the token signature, issuer, audience, and required claims.
//
// Parameters:
// - tokenString: The raw JWT logout token
//
// Returns:
// - The parsed logout token claims
// - An error if validation fails
func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
// Parse the JWT
jwt, err := parseJWT(tokenString)
if err != nil {
return nil, fmt.Errorf("failed to parse logout token: %w", err)
}
// Check token type if present
if typ, ok := jwt.Header["typ"].(string); ok {
// The typ should be "logout+jwt" or omitted
if typ != "" && typ != logoutTokenType && typ != "JWT" {
return nil, fmt.Errorf("invalid token type: %s", typ)
}
}
// Verify signature only (not standard claims - logout tokens don't have 'exp')
if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
return nil, fmt.Errorf("signature verification failed: %w", err)
}
// Extract claims
claims := &LogoutTokenClaims{}
claimsJSON, err := json.Marshal(jwt.Claims)
if err != nil {
return nil, fmt.Errorf("failed to marshal claims: %w", err)
}
if err := json.Unmarshal(claimsJSON, claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
}
// Validate required claims
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
// Validate issuer
if claims.Issuer != expectedIssuer {
return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
}
// Validate audience (must contain our client_id)
if !t.validateLogoutTokenAudience(claims.Audience) {
return nil, fmt.Errorf("audience validation failed")
}
// Validate iat (issued at) - must be present and not too old
if claims.IssuedAt == 0 {
return nil, fmt.Errorf("missing iat claim")
}
iatTime := time.Unix(claims.IssuedAt, 0)
// Allow up to 5 minutes clock skew and 10 minutes token age
if time.Since(iatTime) > 15*time.Minute {
return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
}
// Token should not be from the future (with 5 min clock skew tolerance)
if iatTime.After(time.Now().Add(5 * time.Minute)) {
return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
}
// Validate events claim - must contain the logout event
if claims.Events == nil {
return nil, fmt.Errorf("missing events claim")
}
if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
return nil, fmt.Errorf("missing backchannel-logout event in events claim")
}
// Validate that nonce is NOT present (per spec)
if claims.Nonce != "" {
return nil, fmt.Errorf("nonce claim must not be present in logout token")
}
// Must have either sid or sub (or both)
if claims.SessionID == "" && claims.Subject == "" {
return nil, fmt.Errorf("logout token must contain either sid or sub claim")
}
return claims, nil
}
// validateLogoutTokenAudience checks if the logout token audience contains our client_id
func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
switch v := aud.(type) {
case string:
return v == t.clientID
case []interface{}:
for _, a := range v {
if s, ok := a.(string); ok && s == t.clientID {
return true
}
}
case []string:
for _, a := range v {
if a == t.clientID {
return true
}
}
}
return false
}
// verifyLogoutTokenSignature verifies only the signature of a logout token.
// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
//
// Parameters:
// - jwt: The parsed JWT structure
// - tokenString: The raw token string for signature verification
//
// Returns:
// - An error if signature verification fails
func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
t.logger.Debug("Verifying logout token signature")
// Read jwksURL with RLock
t.metadataMu.RLock()
jwksURL := t.jwksURL
t.metadataMu.RUnlock()
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
if jwks == nil {
return fmt.Errorf("JWKS is nil, cannot verify token")
}
kid, ok := jwt.Header["kid"].(string)
if !ok || kid == "" {
return fmt.Errorf("missing key ID in token header")
}
alg, ok := jwt.Header["alg"].(string)
if !ok || alg == "" {
return fmt.Errorf("missing algorithm in token header")
}
// Find the matching key in JWKS
var matchingKey *JWK
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
matchingKey = &jwks.Keys[i]
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Logout token signature verified successfully")
return nil
}
// invalidateSession marks a session as invalidated in the session invalidation cache.
// It stores entries by both sid and sub if available.
//
// Parameters:
// - sid: The session ID to invalidate (from the 'sid' claim)
// - sub: The subject to invalidate (from the 'sub' claim)
//
// Returns:
// - An error if the invalidation fails
func (t *TraefikOidc) invalidateSession(sid, sub string) error {
if t.sessionInvalidationCache == nil {
return fmt.Errorf("session invalidation cache not initialized")
}
now := time.Now().Unix()
// Store by session ID
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sid: %s", sid)
}
// Store by subject (invalidates all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sub: %s", sub)
}
return nil
}
// isSessionInvalidated checks if a session has been invalidated via backchannel
// or front-channel logout.
//
// Parameters:
// - sid: The session ID to check
// - sub: The subject to check
// - sessionCreatedAt: When the session was created (to compare against invalidation time)
//
// Returns:
// - true if the session has been invalidated, false otherwise
func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
if t.sessionInvalidationCache == nil {
return false
}
// Truncate session creation time to seconds for fair comparison with Unix timestamps
sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
// Check by session ID first (more specific)
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Session was invalidated at or after it was created
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sid: %s", sid)
return true
}
}
}
}
// Check by subject (all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Sessions for this subject created at or before invalidation are invalid
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sub: %s", sub)
return true
}
}
}
}
return false
}
// buildSessionInvalidationKey creates a cache key for session invalidation
func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
}
// extractSessionInfo extracts sid and sub from an ID token for session tracking
func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
if idToken == "" {
return "", "", time.Time{}
}
jwt, err := parseJWT(idToken)
if err != nil {
return "", "", time.Time{}
}
// Extract sid (session ID)
if sidVal, ok := jwt.Claims["sid"].(string); ok {
sid = sidVal
}
// Extract sub (subject)
if subVal, ok := jwt.Claims["sub"].(string); ok {
sub = subVal
}
// Extract iat for session creation time
if iatVal, ok := jwt.Claims["iat"].(float64); ok {
createdAt = time.Unix(int64(iatVal), 0)
} else {
// Default to now if iat not present
createdAt = time.Now()
}
return sid, sub, createdAt
}
// determineLogoutPath checks if the given path matches any logout URL
func (t *TraefikOidc) determineLogoutPath(path string) string {
// Check backchannel logout path
if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
return "backchannel"
}
// Check front-channel logout path
if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
return "frontchannel"
}
// Check regular logout path (for RP-initiated logout)
if path == t.logoutURLPath {
return "rp"
}
return ""
}
// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
func normalizeLogoutPath(path string) string {
if path == "" {
return ""
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// Prevent open redirect: ensure second character is not / or \
// This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
// Strip leading slashes/backslashes and re-normalize
path = strings.TrimLeft(path, "/\\")
if path != "" {
path = "/" + path
}
}
return path
}
+1623
View File
File diff suppressed because it is too large Load Diff
+15 -10
View File
@@ -212,16 +212,21 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return 60 * time.Second
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
enableBackchannelLogout: config.EnableBackchannelLogout,
enableFrontchannelLogout: config.EnableFrontchannelLogout,
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
}
// Log audience configuration
+1 -1
View File
@@ -9,7 +9,7 @@ import (
"gopkg.in/yaml.v3"
)
// Config Marshalling Tests
// Config Marshaling Tests
func TestConfig_MarshalJSON(t *testing.T) {
config := &Config{
+1 -1
View File
@@ -229,7 +229,7 @@ func (mm *MemoryMonitor) updateGoroutineTracking(stats *MemoryStats) {
}
// Check for potential goroutine leak
if stats.NumGoroutines > mm.baselineGoroutines+int(mm.alertThresholds.GoroutineCount) {
if stats.NumGoroutines > mm.baselineGoroutines+mm.alertThresholds.GoroutineCount {
mm.mu.Lock()
wasAlert := mm.goroutineLeakAlert
if !wasAlert {
+62 -6
View File
@@ -26,6 +26,31 @@ import (
// - rw: The HTTP response writer.
// - req: The incoming HTTP request.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Log request entry for debugging routing issues
t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path)
// Handle logout requests early - before waiting for OIDC initialization
// This allows users to logout even if the OIDC provider is unavailable
if req.URL.Path == t.logoutURLPath {
t.logger.Debugf("Logout path matched early: %s", req.URL.Path)
t.handleLogout(rw, req)
return
}
// Handle backchannel logout (IdP-initiated POST with logout_token)
if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath {
t.logger.Debug("Backchannel logout path matched")
t.handleBackchannelLogout(rw, req)
return
}
// Handle front-channel logout (IdP-initiated GET with sid/iss in iframe)
if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath {
t.logger.Debug("Front-channel logout path matched")
t.handleFrontchannelLogout(rw, req)
return
}
if !strings.HasPrefix(req.URL.Path, "/health") {
t.firstRequestMutex.Lock()
if !t.firstRequestReceived {
@@ -42,6 +67,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.firstRequestMutex.Unlock()
}
// Check excluded URLs before waiting for initialization
if t.determineExcludedURL(req.URL.Path) {
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
t.next.ServeHTTP(rw, req)
return
}
// Check for SSE requests before waiting for initialization
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
t.next.ServeHTTP(rw, req)
return
}
// Log waiting for initialization to help diagnose hanging requests
t.logger.Debug("Waiting for OIDC provider initialization...")
select {
case <-t.initComplete:
// Read issuerURL with RLock
@@ -83,7 +126,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.next.ServeHTTP(rw, req)
return
}
acceptHeader := req.Header.Get("Accept")
acceptHeader = req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
// Set forwarded user headers from existing session before bypassing
@@ -100,7 +143,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.next.ServeHTTP(rw, req)
return
}
t.sessionManager.CleanupOldCookies(rw, req)
session, err := t.sessionManager.GetSession(req)
@@ -131,10 +173,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
return
@@ -275,6 +313,24 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
return
}
// Check if session has been invalidated via backchannel or front-channel logout
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
idToken := session.GetIDToken()
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
// Clear the session and redirect to login
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
}
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
}
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
tokenForClaims = session.GetAccessToken()
+32
View File
@@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
}
}
// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable
// This is critical for allowing users to clear their session when the provider is down
func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
logoutURLPath: "/logout",
postLogoutRedirectURI: "/",
forceHTTPS: false,
}
// Note: initComplete is NOT closed, simulating OIDC provider being unavailable
req := httptest.NewRequest("GET", "/logout", nil)
req.Host = "example.com"
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Should redirect to post-logout URI even without OIDC initialization
if rw.Code != http.StatusFound {
t.Errorf("Expected redirect (302) for logout, got %d", rw.Code)
}
location := rw.Header().Get("Location")
if location == "" {
t.Error("Expected Location header for logout redirect")
}
}
// TestMiddlewareDomainRestrictions tests domain-based access control
// NOTE: Currently commented out due to complex session setup requirements
// These scenarios are tested indirectly through integration tests
+2 -2
View File
@@ -234,7 +234,7 @@ func (rc *RefreshCoordinator) CoordinateRefresh(
// Returns (operation, false, nil) if joined an existing operation
// Returns (nil, false, error) if the operation was rejected
func (rc *RefreshCoordinator) getOrCreateOperation(
ctx context.Context,
_ context.Context,
sessionID string,
tokenHash string,
refreshToken string,
@@ -293,7 +293,7 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
// executeRefreshAsync performs the actual refresh operation asynchronously
func (rc *RefreshCoordinator) executeRefreshAsync(
operation *refreshOperation,
sessionID string,
_ string, // sessionID - reserved for future metrics/logging
tokenHash string,
refreshFunc func() (*TokenResponse, error),
) {
+3 -14
View File
@@ -164,7 +164,7 @@ func decompressCombinedPayload(compressed string) (*combinedSessionPayload, erro
if err != nil {
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gr.Close()
defer func() { _ = gr.Close() }()
// Limit decompressed size to prevent zip bombs
limitedReader := io.LimitReader(gr, 512*1024) // 512KB max
@@ -1588,7 +1588,7 @@ func (sd *SessionData) returnToPoolSafely() {
// Parameters:
// - r: The HTTP request context.
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
func (sd *SessionData) clearTokenChunks(_ *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
clearSessionValues(session, true)
}
@@ -1820,23 +1820,12 @@ func (sd *SessionData) SetAccessToken(token string) {
defer sd.sessionMutex.Unlock()
if token != "" {
dotCount := strings.Count(token, ".")
// Reject tokens with exactly 1 dot (invalid format - neither JWT nor opaque)
if dotCount == 1 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Invalid token format during storage (dots: %d) - rejecting", dotCount)
}
return
}
// For opaque tokens (no dots), ensure minimum length for security
if dotCount == 0 && len(token) < 20 {
if len(token) < 20 {
if sd.manager != nil && sd.manager.logger != nil {
sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
}
return
}
// Tokens with 2 dots are JWTs, tokens with 0 dots are opaque
// Both are valid formats
}
currentAccessToken := sd.getAccessTokenUnsafe()
+2
View File
@@ -926,6 +926,8 @@ func (cm *ChunkManager) detectRepeatedCharacters(token string, config TokenConfi
//
// Returns:
// - An error if the token is expired or has invalid expiration, nil if valid.
//
//nolint:unparam // error return kept for API consistency and future use
func (cm *ChunkManager) validateTokenExpiration(token string, config TokenConfig) error {
if !strings.Contains(token, ".") {
return nil
+5 -114
View File
@@ -65,6 +65,10 @@ type Config struct {
ForceHTTPS bool `json:"forceHTTPS"`
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
}
// RedisConfig configures Redis cache backend settings for distributed caching.
@@ -730,6 +734,7 @@ func (l *Logger) Errorf(format string, args ...interface{}) {
}
// newNoOpLogger creates a logger that discards all output.
//
// Deprecated: Use GetSingletonNoOpLogger() instead for better memory efficiency.
func newNoOpLogger() *Logger {
return GetSingletonNoOpLogger()
@@ -744,15 +749,6 @@ func newNoOpLogger() *Logger {
// - code: The HTTP status code for the response.
// - logger: The Logger instance to use for logging the error.
//
// handleError writes an HTTP error response with the specified status code and message.
// It logs the error and sets appropriate headers before writing the response.
//
//lint:ignore U1000 Kept for potential future error handling
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error("%s", message)
http.Error(w, message, code)
}
// GetSecurityHeadersApplier returns a function that applies security headers
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
@@ -1058,111 +1054,6 @@ func (rc *RedisConfig) ApplyEnvFallbacks() {
}
}
// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
// This function is kept for backward compatibility but should not be used directly.
func LoadRedisConfigFromEnv() *RedisConfig {
// Check if Redis is enabled
enabledStr := os.Getenv("REDIS_ENABLED")
if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
return nil
}
config := &RedisConfig{
Enabled: true,
}
// Parse numeric values
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
if db, err := strconv.Atoi(dbStr); err == nil {
config.DB = db
}
}
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
config.PoolSize = poolSize
}
}
if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
config.ConnectTimeout = timeout
}
}
if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
config.ReadTimeout = timeout
}
}
if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
config.WriteTimeout = timeout
}
}
// Parse boolean values
if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
config.EnableTLS = true
}
if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
config.TLSSkipVerify = true
}
// Parse hybrid mode settings
if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
if size, err := strconv.Atoi(l1SizeStr); err == nil {
config.HybridL1Size = size
}
}
if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
config.HybridL1MemoryMB = memory
}
}
// Parse circuit breaker settings
if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
config.EnableCircuitBreaker = false
} else {
config.EnableCircuitBreaker = true // Default to enabled
}
if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
config.CircuitBreakerThreshold = threshold
}
}
if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
config.CircuitBreakerTimeout = timeout
}
}
// Parse health check settings
if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
config.EnableHealthCheck = false
} else {
config.EnableHealthCheck = true // Default to enabled
}
if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
config.HealthCheckInterval = interval
}
}
// Apply defaults after loading from env
config.ApplyDefaults()
return config
}
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if origin == allowed || allowed == "*" {
+51 -5
View File
@@ -4,7 +4,10 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"sync"
"sync/atomic"
@@ -251,6 +254,30 @@ func TestSingletonResourceManager(t *testing.T) {
})
}
// createMockOIDCServer creates a mock OIDC server for testing
func createMockOIDCServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/authorize",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/jwks",
"userinfo_endpoint": "https://example.com/userinfo",
})
case "/jwks":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []interface{}{},
})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
}
// TestContextAwareGoroutineManagement tests context-aware goroutine management
func TestContextAwareGoroutineManagement(t *testing.T) {
t.Run("GoroutineCleanupOnContextCancel", func(t *testing.T) {
@@ -259,13 +286,17 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
// Create mock OIDC server
mockServer := createMockOIDCServer()
defer mockServer.Close()
initialGoroutines := runtime.NumGoroutine()
ctx, cancel := context.WithCancel(context.Background())
// Create a TraefikOidc instance with context
config := &Config{
ProviderURL: "https://example.com",
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
}
@@ -308,12 +339,20 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
// Create mock OIDC servers
mockServer1 := createMockOIDCServer()
defer mockServer1.Close()
mockServer2 := createMockOIDCServer()
defer mockServer2.Close()
mockServer3 := createMockOIDCServer()
defer mockServer3.Close()
initialGoroutines := runtime.NumGoroutine()
configs := []Config{
{ProviderURL: "https://example1.com", ClientID: "client1", ClientSecret: "secret1"},
{ProviderURL: "https://example2.com", ClientID: "client2", ClientSecret: "secret2"},
{ProviderURL: "https://example3.com", ClientID: "client3", ClientSecret: "secret3"},
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
}
var plugins []*TraefikOidc
@@ -366,6 +405,13 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
// Create mock OIDC servers
mockServers := make([]*httptest.Server, 3)
for i := 0; i < 3; i++ {
mockServers[i] = createMockOIDCServer()
defer mockServers[i].Close()
}
rm := GetResourceManager()
// Register singleton cleanup task
@@ -386,7 +432,7 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
for i := 0; i < 3; i++ {
ctx := context.Background()
config := &Config{
ProviderURL: fmt.Sprintf("https://example%d.com", i),
ProviderURL: mockServers[i].URL,
ClientID: fmt.Sprintf("client%d", i),
ClientSecret: fmt.Sprintf("secret%d", i),
}
+5 -5
View File
@@ -22,7 +22,7 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
// Test helper adapters for the new test files
// resetGlobalState resets all global singletons to prevent test interference
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func resetGlobalState() {
// Reset global task registry first to stop all background tasks
@@ -137,7 +137,7 @@ func (tc *testCleanup) cleanupAll() {
}
// createTestConfig creates a config with all required fields populated for testing
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func createTestConfig() *Config {
config := CreateConfig()
@@ -151,7 +151,7 @@ func createTestConfig() *Config {
*/
// setupTestOIDCMiddleware creates a test OIDC middleware instance with mock servers
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httptest.Server) {
// Reset global state to ensure test isolation
@@ -339,7 +339,7 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
*/
// createMockJWT creates a mock JWT token for testing - adapter for existing tests
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func createMockJWT(t *testing.T, sub, email string) string {
return ValidIDToken
@@ -361,7 +361,7 @@ func createTestSession() *SessionData {
}
// injectSessionIntoRequest saves the session and adds the resulting cookies to the request
// nolint:unused // Kept for potential future use in integration tests
//nolint:unused // Kept for potential future use in integration tests
/*
func injectSessionIntoRequest(t *testing.T, req *http.Request, session *SessionData) {
// Create a response recorder to capture cookies
-1
View File
@@ -454,7 +454,6 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
if err != nil {
errMsg := err.Error()
//nolint:gocritic // Complex error handling with provider-specific conditions
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
t.logger.Debug("Refresh token expired or revoked: %v", err)
// Clear all tokens and authentication state when refresh token is invalid
+5
View File
@@ -119,6 +119,8 @@ type TraefikOidc struct {
clientID string
clientSecret string
registrationURL string
backchannelLogoutPath string
frontchannelLogoutPath string
scopesSupported []string
scopes []string
refreshGracePeriod time.Duration
@@ -126,7 +128,10 @@ type TraefikOidc struct {
shutdownOnce sync.Once
metadataRetryMutex sync.Mutex
firstRequestMutex sync.Mutex
sessionInvalidationCache CacheInterface
minimalHeaders bool
enableBackchannelLogout bool
enableFrontchannelLogout bool
firstRequestReceived bool
requireTokenIntrospection bool
metadataRefreshStarted bool
+75 -20
View File
@@ -21,6 +21,10 @@ const (
CacheTypeJWK CacheType = "jwk"
CacheTypeSession CacheType = "session"
CacheTypeGeneral CacheType = "general"
// maxCacheEntrySize defines the maximum size for a single cache entry (64 MiB)
// This prevents integer overflow when allocating memory for serialization
maxCacheEntrySize = 64 * 1024 * 1024
)
// UniversalCacheConfig provides configuration for the universal cache
@@ -720,22 +724,6 @@ func (c *UniversalCache) SetWithMetadata(key string, value interface{}, ttl time
return nil
}
// GetTyped retrieves a typed value from the cache
func GetTyped[T any](c *UniversalCache, key string) (T, bool) {
var zero T
value, exists := c.Get(key)
if !exists {
return zero, false
}
typed, ok := value.(T)
if !ok {
return zero, false
}
return typed, true
}
// TokenCacheOperations provides token-specific operations
func (c *UniversalCache) BlacklistToken(token string, ttl time.Duration) error {
if c.config.Type != CacheTypeToken {
@@ -784,14 +772,81 @@ func (c *UniversalCache) Strategy() CacheStrategy {
// serialize converts a value to bytes for backend storage
func (c *UniversalCache) serialize(value interface{}) ([]byte, error) {
// Use JSON for serialization - simple and universal
return json.Marshal(value)
// If value is already a byte slice (e.g., pre-marshaled JSON from metadata_cache),
// store it directly with a marker to prevent double-encoding.
// This fixes the issue where []byte was being JSON-marshaled, causing Base64 encoding.
if bytes, ok := value.([]byte); ok {
// Validate size to prevent integer overflow
if len(bytes) > maxCacheEntrySize {
return nil, fmt.Errorf("cache entry size %d exceeds maximum allowed size %d", len(bytes), maxCacheEntrySize)
}
// Check for potential overflow when adding marker byte
if len(bytes) == maxCacheEntrySize {
return nil, fmt.Errorf("cache entry size would overflow when adding marker byte")
}
// Prepend marker byte 0x00 to indicate raw bytes (not JSON-encoded)
result := make([]byte, len(bytes)+1)
result[0] = 0x00
copy(result[1:], bytes)
return result, nil
}
// For all other types (maps, strings, etc.), use JSON encoding
// Prepend marker byte 0x01 to indicate JSON-encoded data
jsonData, err := json.Marshal(value)
if err != nil {
return nil, err
}
// Validate size to prevent integer overflow
if len(jsonData) > maxCacheEntrySize {
return nil, fmt.Errorf("serialized cache entry size %d exceeds maximum allowed size %d", len(jsonData), maxCacheEntrySize)
}
// Check for potential overflow when adding marker byte
if len(jsonData) == maxCacheEntrySize {
return nil, fmt.Errorf("serialized cache entry size would overflow when adding marker byte")
}
result := make([]byte, len(jsonData)+1)
result[0] = 0x01
copy(result[1:], jsonData)
return result, nil
}
// deserialize converts bytes from backend storage to a value
func (c *UniversalCache) deserialize(data []byte, value interface{}) error {
// Use JSON for deserialization
return json.Unmarshal(data, value)
if len(data) == 0 {
return fmt.Errorf("cannot deserialize empty data")
}
// Check for type marker (added by serialize)
if data[0] == 0x00 {
// Raw bytes - strip marker and return as-is
rawBytes := data[1:]
if ptr, ok := value.(*interface{}); ok {
*ptr = rawBytes
return nil
}
return fmt.Errorf("cannot deserialize raw bytes into %T", value)
}
if data[0] == 0x01 {
// JSON-encoded - strip marker and unmarshal
return json.Unmarshal(data[1:], value)
}
// Legacy data without marker (for backward compatibility)
// Try to unmarshal as JSON
if err := json.Unmarshal(data, value); err != nil {
// If unmarshal fails, treat as raw bytes
if ptr, ok := value.(*interface{}); ok {
*ptr = data
return nil
}
return err
}
return nil
}
// prefixKey adds a cache type prefix to the key for backend storage
+517
View File
@@ -0,0 +1,517 @@
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestUniversalCache_SerializeDeserialize tests the fix for issue #116
// where metadata was stored as Base64-encoded JSON but read as plain JSON
func TestUniversalCache_SerializeDeserialize(t *testing.T) {
t.Parallel()
t.Run("RawBytesPreserved", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Test data: pre-marshaled JSON bytes (like metadata_cache uses)
testData := []byte(`{"issuer":"https://example.com","jwks_uri":"https://example.com/jwks"}`)
// Serialize
serialized, err := cache.serialize(testData)
require.NoError(t, err)
assert.NotNil(t, serialized)
// Should have marker byte
assert.Equal(t, byte(0x00), serialized[0], "Should have raw bytes marker")
assert.Equal(t, testData, serialized[1:], "Data should be preserved after marker")
// Deserialize
var result interface{}
err = cache.deserialize(serialized, &result)
require.NoError(t, err)
// Should get back []byte
resultBytes, ok := result.([]byte)
require.True(t, ok, "Result should be []byte")
assert.Equal(t, testData, resultBytes, "Deserialized data should match original")
})
t.Run("JSONEncodedTypes", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
testCases := []struct {
name string
value interface{}
}{
{
name: "Map",
value: map[string]interface{}{"key": "value", "number": 42.0},
},
{
name: "String",
value: "test-string",
},
{
name: "Number",
value: 123.456,
},
{
name: "Array",
value: []interface{}{"a", "b", "c"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Serialize
serialized, err := cache.serialize(tc.value)
require.NoError(t, err)
assert.NotNil(t, serialized)
// Should have JSON marker byte
assert.Equal(t, byte(0x01), serialized[0], "Should have JSON marker")
// Verify the JSON portion is valid
var checkJSON interface{}
err = json.Unmarshal(serialized[1:], &checkJSON)
require.NoError(t, err, "Should be valid JSON after marker")
// Deserialize
var result interface{}
err = cache.deserialize(serialized, &result)
require.NoError(t, err)
// Compare results (using JSON round-trip for consistent comparison)
expectedJSON, _ := json.Marshal(tc.value)
resultJSON, _ := json.Marshal(result)
assert.JSONEq(t, string(expectedJSON), string(resultJSON), "Deserialized data should match original")
})
}
})
t.Run("LegacyDataCompatibility", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Simulate legacy data (JSON without marker byte)
legacyData := []byte(`{"legacy":"data"}`)
var result interface{}
err := cache.deserialize(legacyData, &result)
require.NoError(t, err)
// Should successfully unmarshal as JSON
resultMap, ok := result.(map[string]interface{})
require.True(t, ok, "Should unmarshal legacy JSON data")
assert.Equal(t, "data", resultMap["legacy"])
})
t.Run("EmptyDataHandling", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
var result interface{}
err := cache.deserialize([]byte{}, &result)
assert.Error(t, err, "Should error on empty data")
assert.Contains(t, err.Error(), "empty data")
})
t.Run("OverflowProtection_LargeBytes", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a byte slice that exceeds maxCacheEntrySize (64 MiB)
oversizedBytes := make([]byte, 65*1024*1024) // 65 MiB
// Attempt to serialize - should fail with overflow error
_, err := cache.serialize(oversizedBytes)
require.Error(t, err, "Should error on oversized byte slice")
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
})
t.Run("OverflowProtection_ExactMaxSize", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a byte slice exactly at maxCacheEntrySize
// This should fail because adding marker byte would overflow
exactMaxBytes := make([]byte, 64*1024*1024) // Exactly 64 MiB
_, err := cache.serialize(exactMaxBytes)
require.Error(t, err, "Should error when adding marker would overflow")
assert.Contains(t, err.Error(), "would overflow when adding marker byte")
})
t.Run("OverflowProtection_SafeSize", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a byte slice well within limits
safeBytes := make([]byte, 1024*1024) // 1 MiB - safe size
serialized, err := cache.serialize(safeBytes)
require.NoError(t, err, "Should succeed with safe size")
assert.NotNil(t, serialized)
assert.Equal(t, len(safeBytes)+1, len(serialized), "Should add marker byte")
})
t.Run("OverflowProtection_JSONData", func(t *testing.T) {
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
})
defer cache.Close()
// Create a very large map that will exceed limits when JSON-encoded
largeMap := make(map[string]string)
// Each entry is roughly 50 bytes, so we need ~1.3M entries to exceed 64 MiB
for i := 0; i < 1400000; i++ {
key := fmt.Sprintf("key_%d", i)
largeMap[key] = "value_with_some_content_to_make_it_larger"
}
_, err := cache.serialize(largeMap)
require.Error(t, err, "Should error when JSON serialization exceeds size limit")
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
})
}
// TestUniversalCache_RedisIntegration_Issue116 tests the complete fix for issue #116
// with actual Redis backend to ensure metadata cache works correctly
func TestUniversalCache_RedisIntegration_Issue116(t *testing.T) {
t.Parallel()
// Start miniredis server
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
// Create Redis backend
redisConfig := backends.DefaultRedisConfig(mr.Addr())
redisConfig.RedisPrefix = "test:"
backend, err := backends.NewRedisBackend(redisConfig)
require.NoError(t, err)
defer backend.Close()
t.Run("MetadataCache_StoreAndRetrieve", func(t *testing.T) {
// Create cache with Redis backend
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
}, backend)
defer cache.Close()
// Simulate metadata_cache.Set behavior:
// 1. Marshal metadata to JSON
metadata := ProviderMetadata{
Issuer: "https://example.com",
JWKSURL: "https://example.com/jwks",
TokenURL: "https://example.com/token",
AuthURL: "https://example.com/authorize",
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
// 2. Store the JSON bytes
key := "v2:https://example.com"
err = cache.Set(key, jsonData, 1*time.Hour)
require.NoError(t, err)
// 3. Retrieve the data
retrieved, exists := cache.Get(key)
require.True(t, exists, "Data should exist in cache")
// 4. Should get back []byte (not a string or map)
retrievedBytes, ok := retrieved.([]byte)
require.True(t, ok, "Retrieved value should be []byte, got %T", retrieved)
// 5. Should be able to unmarshal as JSON
var retrievedMetadata ProviderMetadata
err = json.Unmarshal(retrievedBytes, &retrievedMetadata)
require.NoError(t, err, "Should be able to unmarshal retrieved bytes as JSON")
// 6. Verify data integrity
assert.Equal(t, metadata.Issuer, retrievedMetadata.Issuer)
assert.Equal(t, metadata.JWKSURL, retrievedMetadata.JWKSURL)
assert.Equal(t, metadata.TokenURL, retrievedMetadata.TokenURL)
})
t.Run("MetadataCache_NoBase64Encoding", func(t *testing.T) {
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
}, backend)
defer cache.Close()
// Store JSON bytes
jsonData := []byte(`{"issuer":"https://test.com"}`)
key := "v2:https://test.com"
err = cache.Set(key, jsonData, 1*time.Hour)
require.NoError(t, err)
// Retrieve
retrieved, exists := cache.Get(key)
require.True(t, exists)
retrievedBytes, ok := retrieved.([]byte)
require.True(t, ok)
// The retrieved data should NOT start with "eyJ" (Base64 encoding of "{")
// This was the bug in issue #116
assert.NotEqual(t, []byte("eyJ"), retrievedBytes[:3], "Data should not be Base64 encoded")
// Should be valid JSON
var checkJSON map[string]interface{}
err = json.Unmarshal(retrievedBytes, &checkJSON)
require.NoError(t, err, "Data should be valid JSON")
assert.Equal(t, "https://test.com", checkJSON["issuer"])
})
t.Run("TokenCache_MapValues", func(t *testing.T) {
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 100,
}, backend)
defer cache.Close()
// Store a map (like TokenCache does)
claims := map[string]interface{}{
"sub": "user123",
"exp": 1234567890.0,
"scope": "read write",
}
key := "token:abc123"
err = cache.Set(key, claims, 10*time.Minute)
require.NoError(t, err)
// Retrieve
retrieved, exists := cache.Get(key)
require.True(t, exists)
// Should get back a map
retrievedMap, ok := retrieved.(map[string]interface{})
require.True(t, ok, "Retrieved value should be map[string]interface{}")
assert.Equal(t, "user123", retrievedMap["sub"])
assert.Equal(t, 1234567890.0, retrievedMap["exp"])
})
t.Run("MixedTypes_SameCache", func(t *testing.T) {
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
}, backend)
defer cache.Close()
// Store different types
jsonBytes := []byte(`{"type":"json-bytes"}`)
err = cache.Set("key1", jsonBytes, 1*time.Hour)
require.NoError(t, err)
mapData := map[string]interface{}{"type": "map"}
err = cache.Set("key2", mapData, 1*time.Hour)
require.NoError(t, err)
stringData := "plain-string"
err = cache.Set("key3", stringData, 1*time.Hour)
require.NoError(t, err)
// Retrieve and verify each type
val1, exists := cache.Get("key1")
require.True(t, exists)
bytes1, ok := val1.([]byte)
require.True(t, ok)
assert.Equal(t, jsonBytes, bytes1)
val2, exists := cache.Get("key2")
require.True(t, exists)
map2, ok := val2.(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "map", map2["type"])
val3, exists := cache.Get("key3")
require.True(t, exists)
str3, ok := val3.(string)
require.True(t, ok)
assert.Equal(t, stringData, str3)
})
}
// TestUniversalCache_BackwardCompatibility tests that old cached data is handled gracefully
func TestUniversalCache_BackwardCompatibility(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisConfig := backends.DefaultRedisConfig(mr.Addr())
backend, err := backends.NewRedisBackend(redisConfig)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("LegacyJSONData", func(t *testing.T) {
// Manually insert legacy data (plain JSON without marker)
legacyKey := "general:legacy-key"
legacyData := []byte(`{"old":"format"}`)
err = backend.Set(ctx, legacyKey, legacyData, 1*time.Hour)
require.NoError(t, err)
// Try to retrieve via UniversalCache
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
}, backend)
defer cache.Close()
retrieved, exists := cache.Get("legacy-key")
require.True(t, exists, "Should retrieve legacy data")
// Should deserialize as JSON map
retrievedMap, ok := retrieved.(map[string]interface{})
require.True(t, ok, "Should unmarshal legacy JSON")
assert.Equal(t, "format", retrievedMap["old"])
})
t.Run("LegacyCorruptData", func(t *testing.T) {
// Insert corrupt/invalid data
corruptKey := "general:corrupt-key"
corruptData := []byte("not json and no marker")
err = backend.Set(ctx, corruptKey, corruptData, 1*time.Hour)
require.NoError(t, err)
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
}, backend)
defer cache.Close()
retrieved, exists := cache.Get("corrupt-key")
require.True(t, exists)
// Should return as raw bytes (fallback)
retrievedBytes, ok := retrieved.([]byte)
require.True(t, ok, "Should return corrupt data as raw bytes")
assert.Equal(t, corruptData, retrievedBytes)
})
}
// TestMetadataCache_Issue116_Regression is the main regression test for issue #116
// This specifically tests the scenario described in the GitHub issue
func TestMetadataCache_Issue116_Regression(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
// Create Redis backend
redisConfig := backends.DefaultRedisConfig(mr.Addr())
redisConfig.RedisPrefix = "traefik:"
backend, err := backends.NewRedisBackend(redisConfig)
require.NoError(t, err)
defer backend.Close()
// Create a simple logger
logger := GetSingletonNoOpLogger()
// Create metadata cache instance
metadataCache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
Logger: logger,
SkipAutoCleanup: true,
}, backend)
defer metadataCache.Close()
// Use the actual MetadataCache wrapper
wg := &sync.WaitGroup{}
mc := &MetadataCache{
cache: metadataCache,
logger: logger,
wg: wg,
}
// Test: Store and retrieve metadata (the scenario from issue #116)
providerURL := "https://example.com"
metadata := &ProviderMetadata{
Issuer: "https://example.com",
AuthURL: "https://example.com/authorize",
TokenURL: "https://example.com/token",
JWKSURL: "https://example.com/jwks",
RevokeURL: "https://example.com/revoke",
EndSessionURL: "https://example.com/logout",
RegistrationURL: "https://example.com/register",
ScopesSupported: []string{"openid", "profile", "email"},
}
// Store metadata
err = mc.Set(providerURL, metadata, 1*time.Hour)
require.NoError(t, err, "Should store metadata without error")
// Retrieve metadata
retrieved, exists := mc.Get(providerURL)
require.True(t, exists, "Should retrieve stored metadata")
require.NotNil(t, retrieved, "Retrieved metadata should not be nil")
// Verify no corruption - this was failing in issue #116 with "invalid character 'e'" error
assert.Equal(t, metadata.Issuer, retrieved.Issuer)
assert.Equal(t, metadata.AuthURL, retrieved.AuthURL)
assert.Equal(t, metadata.TokenURL, retrieved.TokenURL)
assert.Equal(t, metadata.JWKSURL, retrieved.JWKSURL)
// Verify the data is not Base64-encoded in Redis
// This checks the root cause mentioned in the issue
ctx := context.Background()
rawData, _, exists, err := backend.Get(ctx, "metadata:v2:"+providerURL)
require.NoError(t, err)
require.True(t, exists)
// Strip the marker byte
require.Greater(t, len(rawData), 1, "Data should have marker byte")
dataWithoutMarker := rawData[1:]
// Should not start with "eyJ" (Base64 encoding of "{")
if len(dataWithoutMarker) >= 3 {
assert.NotEqual(t, "eyJ", string(dataWithoutMarker[:3]), "Data should not be Base64-encoded")
}
// Should be valid JSON
var checkMetadata ProviderMetadata
err = json.Unmarshal(dataWithoutMarker, &checkMetadata)
require.NoError(t, err, "Stored data should be valid JSON, not Base64")
assert.Equal(t, metadata.Issuer, checkMetadata.Issuer)
}
+48 -52
View File
@@ -13,21 +13,22 @@ import (
// It runs a single consolidated cleanup goroutine for all caches, reducing
// goroutine count and CPU overhead compared to per-cache cleanup routines.
type UniversalCacheManager struct {
sharedBackend backends.CacheBackend
ctx context.Context
tokenTypeCache *UniversalCache
jwkCache *UniversalCache
sessionCache *UniversalCache
introspectionCache *UniversalCache
tokenCache *UniversalCache
metadataCache *UniversalCache
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
logger *Logger
blacklistCache *UniversalCache
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
cleanupStarted bool
sharedBackend backends.CacheBackend
ctx context.Context
tokenTypeCache *UniversalCache
jwkCache *UniversalCache
sessionCache *UniversalCache
introspectionCache *UniversalCache
tokenCache *UniversalCache
metadataCache *UniversalCache
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
logger *Logger
blacklistCache *UniversalCache
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
cleanupStarted bool
}
var (
@@ -170,6 +171,16 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
// Initialize session invalidation cache for backchannel/front-channel logout
// This cache stores invalidated session IDs and subjects to revoke sessions
manager.sessionInvalidationCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 5000, // Support many concurrent invalidations
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
}
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
@@ -363,6 +374,19 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
createBackend("dcr"),
)
// Session invalidation cache - CRITICAL for distributed backchannel/front-channel logout
// Uses Redis backend to share session invalidations across all Traefik replicas
manager.sessionInvalidationCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 5000, // Support many concurrent invalidations
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
},
createBackend("session_invalidation"),
)
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
}
@@ -411,6 +435,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
m.introspectionCache,
m.tokenTypeCache,
m.dcrCredentialsCache,
m.sessionInvalidationCache,
}
m.mu.RUnlock()
@@ -452,13 +477,6 @@ func (m *UniversalCacheManager) GetJWKCache() *UniversalCache {
return m.jwkCache
}
// GetSessionCache returns the session cache
func (m *UniversalCacheManager) GetSessionCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.sessionCache
}
// GetIntrospectionCache returns the token introspection cache
func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache {
m.mu.RLock()
@@ -473,6 +491,13 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
return m.tokenTypeCache
}
// GetSessionInvalidationCache returns the session invalidation cache for backchannel/front-channel logout
func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.sessionInvalidationCache
}
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
m.mu.RLock()
@@ -495,7 +520,7 @@ func (m *UniversalCacheManager) Close() error {
// Close all caches first (they won't close the shared backend)
for _, cache := range []*UniversalCache{
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache,
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
} {
if cache != nil {
_ = cache.Close() // Safe to ignore: best effort cache cleanup
@@ -516,35 +541,6 @@ func (m *UniversalCacheManager) Close() error {
return nil
}
// InitializeCacheManagerFromConfig initializes the cache manager with configuration
// This should be called early in the application startup with the loaded configuration
func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
logger := NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides only when
// the config field is not explicitly set through Traefik
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
// Log cache backend selection
if config.Redis != nil && config.Redis.Enabled {
logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
config.Redis.CacheMode, config.Redis.Address)
} else {
logger.Info("Initializing cache backend with memory-only mode")
}
return GetUniversalCacheManagerWithConfig(logger, config.Redis)
}
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
// This should only be called in test code to ensure proper cleanup between tests
func ResetUniversalCacheManagerForTesting() {