mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bfd702a447 | |||
| 68c150eba4 | |||
| 9cbca4c4fb | |||
| 684a990f59 | |||
| 1b6c8616fd | |||
| 4d28fa01ab | |||
| 2d1b04c637 | |||
| ccbb98b9dd | |||
| 1362cc0dac | |||
| 249dcad1b3 | |||
| de4b4d7258 | |||
| 9d52f1b018 | |||
| 57724918fe | |||
| 775de2ada1 | |||
| 7816e05c98 | |||
| 8bf7998150 | |||
| 22c4323fcb | |||
| 06b219d1f8 | |||
| 413e4a1b7d | |||
| 69e0d98c67 | |||
| 6d893df12b |
@@ -11,7 +11,9 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
docker/
|
||||
.claude/*.out
|
||||
*.test
|
||||
.leann/
|
||||
|
||||
+49
-32
@@ -14,21 +14,22 @@ linters:
|
||||
- gosec
|
||||
- misspell
|
||||
- noctx
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
disable:
|
||||
- exhaustive
|
||||
- funlen
|
||||
- gocognit
|
||||
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
|
||||
- goprintffuncname # Disabled: naming convention is project-specific
|
||||
- lll
|
||||
- mnd
|
||||
- testpackage
|
||||
- whitespace # Disabled: style preference about newlines
|
||||
- wsl
|
||||
settings:
|
||||
dupl:
|
||||
@@ -47,29 +48,13 @@ linters:
|
||||
- fmt.Fprintln
|
||||
goconst:
|
||||
min-len: 3
|
||||
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
|
||||
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
|
||||
ignore-tests: true
|
||||
gocritic:
|
||||
# Using default enabled checks in v2
|
||||
enabled-checks:
|
||||
- appendCombine
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- commentedOutCode
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- hexLiteral
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nestingReduce
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- stringXbytes
|
||||
- typeAssertChain
|
||||
- typeUnparen
|
||||
- unlabelStmt
|
||||
- yodaStyleExpr
|
||||
# Disable style-only checks that add noise
|
||||
disabled-checks:
|
||||
- ifElseChain # Style preference, switch not always clearer
|
||||
- elseif # Style preference
|
||||
gocyclo:
|
||||
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
|
||||
gosec:
|
||||
@@ -106,23 +91,23 @@ linters:
|
||||
- name: error-return
|
||||
- name: error-strings
|
||||
- name: error-naming
|
||||
- name: exported
|
||||
- name: if-return
|
||||
# - name: exported # Disabled: too noisy, not all exported functions need comments
|
||||
# - name: if-return # Disabled: style preference
|
||||
- name: increment-decrement
|
||||
- name: var-naming
|
||||
- name: var-declaration
|
||||
- name: package-comments
|
||||
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
|
||||
# - name: var-declaration # Disabled: explicit zero values can be clearer
|
||||
# - name: package-comments # Disabled: handled by other tools
|
||||
- name: range
|
||||
- name: receiver-naming
|
||||
- name: time-naming
|
||||
- name: unexported-return
|
||||
- name: indent-error-flow
|
||||
# - name: indent-error-flow # Disabled: style preference
|
||||
- name: errorf
|
||||
- name: empty-block
|
||||
# - name: empty-block # Disabled: sometimes empty blocks are intentional
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
|
||||
- name: unreachable-code
|
||||
- name: redefines-builtin-id
|
||||
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
|
||||
unparam:
|
||||
check-exported: false
|
||||
staticcheck:
|
||||
@@ -132,8 +117,15 @@ linters:
|
||||
- -QF1003 # Tagged switch - style preference, may affect Yaegi
|
||||
- -QF1007 # Merge conditional assignment - style preference
|
||||
- -QF1008 # Remove embedded field - may break Yaegi compatibility
|
||||
- -QF1011 # Omit type from declaration - style preference
|
||||
- -QF1012 # Use fmt.Fprintf - style preference
|
||||
- -SA9003 # Empty branch - sometimes intentional for future work
|
||||
- -ST1000 # Package comment format - not required for all packages
|
||||
- -ST1003 # Package name format - allowed for test packages
|
||||
- -ST1016 # Receiver name consistency - legacy code
|
||||
- -ST1020 # Comment format for methods - style preference
|
||||
- -ST1021 # Comment format for types - style preference
|
||||
- -ST1023 # Omit type from declaration - style preference
|
||||
exclusions:
|
||||
generated: lax
|
||||
rules:
|
||||
@@ -144,18 +136,43 @@ linters:
|
||||
- goconst
|
||||
- gocyclo
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
- revive
|
||||
- gocritic
|
||||
path: _test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
- gocyclo
|
||||
- govet
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
- revive
|
||||
- gocritic
|
||||
path: test.*\.go
|
||||
- linters:
|
||||
- gocritic
|
||||
- unused
|
||||
- errcheck
|
||||
- revive
|
||||
path: mocks.*\.go
|
||||
- linters:
|
||||
- errcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- govet
|
||||
- unparam
|
||||
path: internal/testutil/
|
||||
- linters:
|
||||
- govet
|
||||
- unparam
|
||||
- noctx
|
||||
- prealloc
|
||||
path: integration/
|
||||
- linters:
|
||||
- gosec
|
||||
text: 'G404:'
|
||||
|
||||
@@ -47,3 +47,14 @@ release:
|
||||
name_template: "v{{ .Version }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sigstore.json"
|
||||
args:
|
||||
- sign-blob
|
||||
- "--bundle=${signature}"
|
||||
- "${artifact}"
|
||||
- "--yes"
|
||||
artifacts: checksum
|
||||
output: true
|
||||
|
||||
+47
-1610
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,49 @@
|
||||
# Security Fix: Integer Overflow Protection in Cache Serialization
|
||||
|
||||
## Summary
|
||||
|
||||
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
|
||||
|
||||
## Vulnerability
|
||||
|
||||
**Locations**: `universal_cache.go` lines 789 and 811
|
||||
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
|
||||
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
|
||||
|
||||
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
|
||||
|
||||
## Fix Applied
|
||||
|
||||
1. **Added size limit constant**:
|
||||
```go
|
||||
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
|
||||
```
|
||||
|
||||
2. **Size validation before allocation**:
|
||||
- Validates entry size doesn't exceed limit
|
||||
- Validates adding marker byte won't overflow
|
||||
- Returns descriptive error messages
|
||||
|
||||
3. **Comprehensive test coverage**:
|
||||
- Oversized byte slices (>64 MiB)
|
||||
- Exact max size edge case
|
||||
- Safe sizes (normal operation)
|
||||
- Large JSON data structures
|
||||
|
||||
## Verification
|
||||
|
||||
✅ All tests pass with race detection
|
||||
✅ No security issues (golangci-lint, gosec)
|
||||
✅ 76.3% test coverage maintained
|
||||
|
||||
## Impact
|
||||
|
||||
- No breaking changes
|
||||
- Negligible performance overhead
|
||||
- Prevents potential buffer overflows
|
||||
- Predictable memory usage
|
||||
|
||||
---
|
||||
|
||||
**Date**: January 8, 2026
|
||||
**Severity**: High → Resolved
|
||||
+1
-1
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetUserIdentifier("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
|
||||
+60
-11
@@ -4,8 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
@@ -44,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -77,7 +76,12 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := uuid.NewString()
|
||||
csrfToken, err := newUUIDv4()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate CSRF token: %v", err)
|
||||
http.Error(rw, "Failed to generate CSRF token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
@@ -246,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -286,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -334,9 +338,54 @@ func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
|
||||
strings.Contains(accept, "application/json")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
// This is a heuristic check - actual implementation would depend on
|
||||
// the specific provider and token metadata
|
||||
return false // Placeholder implementation
|
||||
// isNonNavigationRequest reports whether the request is a browser
|
||||
// sub-resource (script, image, stylesheet, fetch, serviceWorker) rather than
|
||||
// a top-level HTML navigation. Non-navigation requests MUST NOT trigger an
|
||||
// OIDC redirect flow: several sub-resource loads happening in parallel would
|
||||
// each call defaultInitiateAuthentication, each overwriting the session's
|
||||
// CSRF/nonce, breaking the eventual callback (issue #129).
|
||||
//
|
||||
// Detection prefers Sec-Fetch-Mode, which all modern browsers send
|
||||
// (Chrome/Edge/Firefox/Safari). For older or non-browser clients we fall
|
||||
// back to Accept: if Accept is present and does not list text/html, treat
|
||||
// it as a sub-resource. An empty/missing Accept is assumed to be navigation
|
||||
// (safer to redirect than 401 on an ambiguous request).
|
||||
func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
|
||||
if mode := req.Header.Get("Sec-Fetch-Mode"); mode != "" {
|
||||
return mode != "navigate"
|
||||
}
|
||||
accept := req.Header.Get("Accept")
|
||||
if accept == "" || accept == "*/*" {
|
||||
return false
|
||||
}
|
||||
return !strings.Contains(accept, "text/html")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks whether the stored refresh token is likely
|
||||
// past its useful lifetime, using the cookie-side issued_at timestamp set by
|
||||
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
|
||||
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
|
||||
// MaxRefreshTokenAgeSeconds; 0 disables the check).
|
||||
//
|
||||
// The point of this check is to short-circuit the refresh path BEFORE the
|
||||
// thundering herd hits the IdP for a token the provider has almost certainly
|
||||
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
|
||||
// style polling clients from looping on invalid_grant after a long pause.
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
if t == nil || session == nil {
|
||||
return false
|
||||
}
|
||||
if t.maxRefreshTokenAge <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
issuedAt := session.GetRefreshTokenIssuedAt()
|
||||
if issuedAt.IsZero() {
|
||||
// No timestamp recorded (legacy session pre-dating the issued_at
|
||||
// field). Don't force a re-auth - attempt refresh once and let the
|
||||
// IdP be the source of truth.
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(issuedAt) > t.maxRefreshTokenAge
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Pre-populate session with old data
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-access-token-with-many-characters")
|
||||
session.SetRefreshToken("old-refresh-token-with-many-characters")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Verify old data is cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
|
||||
// Verify new data is set
|
||||
s.Equal(csrfToken, session.GetCSRF())
|
||||
@@ -305,6 +305,90 @@ func (s *AuthFlowBehaviourSuite) TestIsAjaxRequest() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsNonNavigationRequest verifies browser sub-resource detection used to
|
||||
// suppress OIDC redirects on parallel static-asset loads (issue #129).
|
||||
func (s *AuthFlowBehaviourSuite) TestIsNonNavigationRequest() {
|
||||
testCases := []struct {
|
||||
headers map[string]string
|
||||
name string
|
||||
expectNonNavigation bool
|
||||
}{
|
||||
{
|
||||
name: "Sec-Fetch-Mode navigate",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "navigate"},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode no-cors",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "no-cors"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode cors",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "cors"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode same-origin (fetch in page)",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "same-origin"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Accept text/html (fallback)",
|
||||
headers: map[string]string{"Accept": "text/html,application/xhtml+xml"},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Accept image/png (fallback)",
|
||||
headers: map[string]string{"Accept": "image/png,image/*;q=0.8"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Accept application/javascript (fallback)",
|
||||
headers: map[string]string{"Accept": "application/javascript"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Accept */* treated as navigation",
|
||||
headers: map[string]string{"Accept": "*/*"},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "No Accept header assumed navigation",
|
||||
headers: map[string]string{},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode beats Accept (navigate wins)",
|
||||
headers: map[string]string{
|
||||
"Sec-Fetch-Mode": "navigate",
|
||||
"Accept": "application/javascript",
|
||||
},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode beats Accept (no-cors wins)",
|
||||
headers: map[string]string{
|
||||
"Sec-Fetch-Mode": "no-cors",
|
||||
"Accept": "text/html",
|
||||
},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/_static/asset.js", nil)
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
result := s.tOidc.isNonNavigationRequest(req)
|
||||
s.Equal(tc.expectNonNavigation, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCallback_MissingState tests callback with missing state parameter
|
||||
func (s *AuthFlowBehaviourSuite) TestHandleCallback_MissingState() {
|
||||
sessionManager, err := NewSessionManager(
|
||||
@@ -627,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
session, err := sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
session.mainSession.Values["redirect_count"] = 3
|
||||
|
||||
@@ -636,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
|
||||
// Session should be cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
s.Empty(session.GetIDToken())
|
||||
|
||||
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
|
||||
|
||||
+4
-3
@@ -599,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
|
||||
return globalTaskMemoryMonitor
|
||||
}
|
||||
|
||||
// NewTaskMemoryMonitor creates a new memory monitor for task registry
|
||||
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
|
||||
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
|
||||
//
|
||||
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
|
||||
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
|
||||
return GetGlobalTaskMemoryMonitor(logger)
|
||||
}
|
||||
@@ -712,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
|
||||
|
||||
// Check for goroutine leaks (arbitrary threshold)
|
||||
if stats.Goroutines > 100 {
|
||||
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
|
||||
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
|
||||
}
|
||||
|
||||
// Check for heap growth without corresponding GC activity
|
||||
|
||||
@@ -29,8 +29,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Equal(t, MemoryPressureNone, pressure)
|
||||
|
||||
// Collect stats to populate lastStats
|
||||
monitor.GetCurrentStats()
|
||||
// Explicitly sample to populate lastStats; GetCurrentStats is now a
|
||||
// cached read and no longer forces a runtime.ReadMemStats.
|
||||
monitor.Refresh()
|
||||
|
||||
// Now should return a valid pressure level
|
||||
pressure = monitor.GetMemoryPressure()
|
||||
@@ -46,11 +47,13 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Start monitoring should not panic
|
||||
// Start monitoring should not panic. Interval is clamped to the
|
||||
// minimum (30s); we rely on Refresh() when we need a synchronous
|
||||
// sample instead of waiting for a tick.
|
||||
assert.NotPanics(t, func() {
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 100*time.Millisecond)
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
monitor.StartMonitoring(ctx, 0)
|
||||
monitor.Refresh()
|
||||
})
|
||||
|
||||
// Clean up
|
||||
@@ -117,6 +120,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Refresh forces a synchronous sample; GetCurrentStats is a cached
|
||||
// read, so we sample first to guarantee fresh data.
|
||||
monitor.Refresh()
|
||||
stats := monitor.GetCurrentStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
@@ -450,12 +456,12 @@ func TestMemoryMonitorIntegration(t *testing.T) {
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
defer monitor.StopMonitoring()
|
||||
|
||||
// Start monitoring
|
||||
// Start monitoring. The interval is clamped to the minimum (30s) so
|
||||
// the ticker won't fire during the test; drive the sample manually via
|
||||
// Refresh() instead.
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 50*time.Millisecond)
|
||||
|
||||
// Wait for at least one check
|
||||
time.Sleep(GetTestDuration(150 * time.Millisecond))
|
||||
monitor.StartMonitoring(ctx, 0)
|
||||
monitor.Refresh()
|
||||
|
||||
// Get pressure (should be a valid pressure level)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
@@ -488,6 +494,7 @@ func TestMemoryStatsCollection(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
monitor.Refresh()
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
assert.NotNil(t, stats)
|
||||
@@ -501,6 +508,7 @@ func TestMemoryStatsCollection(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
monitor.Refresh()
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
// Should calculate and include pressure level
|
||||
@@ -521,13 +529,14 @@ func TestMemoryStatsCollection(t *testing.T) {
|
||||
// Allocate some memory
|
||||
_ = make([]byte, 1024*1024) // 1MB
|
||||
|
||||
// Get stats before GC
|
||||
beforeStats := monitor.GetCurrentStats()
|
||||
// Get stats before GC (explicit Refresh so we have a fresh pre-GC
|
||||
// snapshot to compare against, not the constructor baseline).
|
||||
beforeStats := monitor.Refresh()
|
||||
|
||||
// Trigger GC
|
||||
// Trigger GC (internally Refresh()es before and after)
|
||||
monitor.TriggerGC()
|
||||
|
||||
// Get stats after GC
|
||||
// Get stats after GC from cache (TriggerGC already refreshed it)
|
||||
afterStats := monitor.GetCurrentStats()
|
||||
|
||||
// After GC should have different stats
|
||||
|
||||
+137
@@ -0,0 +1,137 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testCertPEM returns a valid PEM-encoded certificate harvested from an
|
||||
// httptest.NewTLSServer. Using httptest keeps the test free of any
|
||||
// handwritten static cert that could expire.
|
||||
func testCertPEM(t *testing.T) string {
|
||||
t.Helper()
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
cert := srv.Certificate()
|
||||
if cert == nil {
|
||||
t.Fatal("httptest.NewTLSServer did not expose a certificate")
|
||||
}
|
||||
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_Empty(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool != nil {
|
||||
t.Errorf("expected nil pool when no CA source configured, got %v", pool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_InlinePEM(t *testing.T) {
|
||||
cfg := &Config{CACertPEM: testCertPEM(t)}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("expected non-nil pool for valid CACertPEM")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_InlinePEM_Garbage(t *testing.T) {
|
||||
cfg := &Config{CACertPEM: "not a pem"}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for garbage CACertPEM, got nil")
|
||||
}
|
||||
if pool != nil {
|
||||
t.Errorf("expected nil pool on error, got %v", pool)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "caCertPEM") {
|
||||
t.Errorf("error should name the failing field, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_FilePath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ca.pem")
|
||||
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
|
||||
t.Fatalf("writing temp PEM: %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{CACertPath: path}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("expected non-nil pool for valid CACertPath")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_FilePath_Missing(t *testing.T) {
|
||||
cfg := &Config{CACertPath: "/does/not/exist/ca.pem"}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing CACertPath, got nil")
|
||||
}
|
||||
if pool != nil {
|
||||
t.Errorf("expected nil pool on error, got %v", pool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_Combined(t *testing.T) {
|
||||
// Both inline and file sources populated — certificates from both should
|
||||
// be accepted into the same pool.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ca.pem")
|
||||
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
|
||||
t.Fatalf("writing temp PEM: %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{CACertPath: path, CACertPEM: testCertPEM(t)}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("expected non-nil pool when both sources set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedTransportPool_ConfigKeyDistinguishesCAAndSkipVerify(t *testing.T) {
|
||||
p := GetGlobalTransportPool()
|
||||
cfgSystem := DefaultHTTPClientConfig()
|
||||
|
||||
cfgSkip := DefaultHTTPClientConfig()
|
||||
cfgSkip.InsecureSkipVerify = true
|
||||
|
||||
cfgCustomCA := DefaultHTTPClientConfig()
|
||||
pool, err := (&Config{CACertPEM: testCertPEM(t)}).loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("loadCACertPool: %v", err)
|
||||
}
|
||||
cfgCustomCA.RootCAs = pool
|
||||
|
||||
keys := map[string]string{
|
||||
"system": p.configKey(cfgSystem),
|
||||
"skip": p.configKey(cfgSkip),
|
||||
"customCA": p.configKey(cfgCustomCA),
|
||||
}
|
||||
seen := make(map[string]string, len(keys))
|
||||
for name, key := range keys {
|
||||
if dup, ok := seen[key]; ok {
|
||||
t.Errorf("configKey collision: %s and %s share key %q", name, dup, key)
|
||||
}
|
||||
seen[key] = name
|
||||
}
|
||||
}
|
||||
+32
-8
@@ -20,8 +20,9 @@ var (
|
||||
cacheManagerInitOnce sync.Once
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance.
|
||||
//
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
return GetGlobalCacheManagerWithConfig(wg, nil)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenCache returns the shared token cache
|
||||
@@ -93,7 +94,7 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
|
||||
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenTypeCache returns the shared token type cache
|
||||
@@ -101,7 +102,23 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
|
||||
// for backchannel and front-channel logout (IdP-initiated logout)
|
||||
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
|
||||
// by the refresh path to coalesce grants across Traefik replicas via Redis.
|
||||
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
@@ -121,7 +138,8 @@ func CleanupGlobalCacheManager() error {
|
||||
|
||||
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
|
||||
type CacheInterfaceWrapper struct {
|
||||
cache *UniversalCache
|
||||
cache *UniversalCache
|
||||
managed bool // If true, cache is managed globally and Close() is a no-op
|
||||
}
|
||||
|
||||
// Set stores a value
|
||||
@@ -149,9 +167,15 @@ func (c *CacheInterfaceWrapper) Cleanup() {
|
||||
c.cache.Cleanup()
|
||||
}
|
||||
|
||||
// Close shuts down the cache
|
||||
// Close shuts down the cache if it's not managed globally.
|
||||
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
|
||||
// when multiple plugin instances are closed during Traefik configuration reloads.
|
||||
func (c *CacheInterfaceWrapper) Close() {
|
||||
// Close the underlying cache to stop goroutines
|
||||
if c.managed {
|
||||
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
|
||||
return
|
||||
}
|
||||
// Standalone cache - close it properly to stop cleanup goroutines
|
||||
if c.cache != nil {
|
||||
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
|
||||
}
|
||||
|
||||
+153
@@ -219,6 +219,159 @@ func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
nilWrapper.Close()
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_ManagedClose_Regression tests that managed cache wrappers
|
||||
// don't close the underlying cache when Close() is called. This is a regression test
|
||||
// for issue #105 where multiple plugin instances closing shared caches caused log flooding.
|
||||
func TestCacheInterfaceWrapper_ManagedClose_Regression(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Get a managed cache wrapper
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Verify it's marked as managed
|
||||
if !wrapper.managed {
|
||||
t.Error("Expected shared cache wrapper to be marked as managed")
|
||||
}
|
||||
|
||||
// Set some data before Close
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Close the wrapper (should be a no-op for managed caches)
|
||||
wrapper.Close()
|
||||
|
||||
// Verify the cache is still operational after Close
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected cache to still work after Close() on managed wrapper")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Can still set new values
|
||||
cache.Set("new-key", "new-value", time.Hour)
|
||||
newValue, found := cache.Get("new-key")
|
||||
if !found || newValue != "new-value" {
|
||||
t.Error("Expected to be able to set new values after Close() on managed wrapper")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_StandaloneClose tests that standalone cache wrappers
|
||||
// properly close the underlying cache when Close() is called.
|
||||
func TestCacheInterfaceWrapper_StandaloneClose(t *testing.T) {
|
||||
// Create a standalone cache (not from the global cache manager)
|
||||
standaloneCache := NewCache()
|
||||
|
||||
wrapper, ok := standaloneCache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Verify it's NOT marked as managed
|
||||
if wrapper.managed {
|
||||
t.Error("Expected standalone cache wrapper to NOT be marked as managed")
|
||||
}
|
||||
|
||||
// Set some data
|
||||
standaloneCache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Get baseline goroutine count
|
||||
baselineGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Close the wrapper (should actually close the underlying cache)
|
||||
wrapper.Close()
|
||||
|
||||
// Give cleanup goroutine time to stop
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Goroutine count should decrease (cleanup routine stopped)
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > baselineGoroutines {
|
||||
// This is acceptable - other tests might have started goroutines
|
||||
t.Logf("Goroutine count: baseline=%d, final=%d", baselineGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_MultipleInstancesClose_Regression tests that multiple
|
||||
// plugin instances can close their cache wrappers without affecting shared caches.
|
||||
// This is a regression test for issue #105.
|
||||
func TestCacheInterfaceWrapper_MultipleInstancesClose_Regression(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Simulate multiple plugin instances getting cache references
|
||||
instances := make([]*CacheInterfaceWrapper, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
instances[i] = wrapper
|
||||
|
||||
// Each instance might set some data
|
||||
cache.Set(fmt.Sprintf("instance-%d-key", i), fmt.Sprintf("value-%d", i), time.Hour)
|
||||
}
|
||||
|
||||
// Close all instances (simulating plugin shutdown/reload)
|
||||
for _, wrapper := range instances {
|
||||
wrapper.Close()
|
||||
}
|
||||
|
||||
// The shared cache should still work after all instances closed their wrappers
|
||||
newCache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Data set by earlier instances should still be accessible
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("instance-%d-key", i)
|
||||
value, found := newCache.Get(key)
|
||||
if !found {
|
||||
t.Errorf("Expected data from instance %d to still be accessible", i)
|
||||
}
|
||||
expectedValue := fmt.Sprintf("value-%d", i)
|
||||
if value != expectedValue {
|
||||
t.Errorf("Expected '%s', got '%v'", expectedValue, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be able to add new data
|
||||
newCache.Set("after-close-key", "after-close-value", time.Hour)
|
||||
value, found := newCache.Get("after-close-key")
|
||||
if !found || value != "after-close-value" {
|
||||
t.Error("Expected to be able to use cache after all wrapper Close() calls")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllSharedCachesMarkedAsManaged verifies all shared cache getters
|
||||
// return managed wrappers to prevent the log flooding issue.
|
||||
func TestAllSharedCachesMarkedAsManaged(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache CacheInterface
|
||||
}{
|
||||
{"TokenBlacklist", cm.GetSharedTokenBlacklist()},
|
||||
{"IntrospectionCache", cm.GetSharedIntrospectionCache()},
|
||||
{"TokenTypeCache", cm.GetSharedTokenTypeCache()},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wrapper, ok := tt.cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatalf("Expected CacheInterfaceWrapper for %s", tt.name)
|
||||
}
|
||||
if !wrapper.managed {
|
||||
t.Errorf("%s cache wrapper should be marked as managed", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
@@ -47,7 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalYAML() (interface{}, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetRefreshToken("old-refresh-token")
|
||||
session.SetIDToken("old-id-token")
|
||||
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Now perform selective clearing (as done in the fix)
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// Set initial session data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-token")
|
||||
session.SetCSRF("existing-csrf")
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
|
||||
// NEW BEHAVIOR: Selective clearing
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
|
||||
@@ -0,0 +1,290 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
|
||||
)
|
||||
|
||||
// DCRStorageBackend represents the type of storage backend for DCR credentials.
|
||||
// Alias for internal package type for backward compatibility.
|
||||
type DCRStorageBackend = dcrstorage.StorageBackend
|
||||
|
||||
const (
|
||||
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
|
||||
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
|
||||
|
||||
// DCRStorageBackendRedis uses Redis for distributed storage
|
||||
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
|
||||
|
||||
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
|
||||
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
|
||||
)
|
||||
|
||||
// DCRCredentialsStore defines the interface for storing DCR credentials.
|
||||
// This abstraction allows different storage backends (file, Redis) to be used
|
||||
// for persisting OIDC Dynamic Client Registration credentials across nodes.
|
||||
type DCRCredentialsStore interface {
|
||||
// Save stores the client registration response for a provider
|
||||
// The providerURL is used as a key to support multi-tenant scenarios
|
||||
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
|
||||
|
||||
// Load retrieves stored credentials for a provider
|
||||
// Returns nil, nil if no credentials exist (not an error)
|
||||
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
|
||||
|
||||
// Delete removes stored credentials for a provider
|
||||
Delete(ctx context.Context, providerURL string) error
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
Exists(ctx context.Context, providerURL string) (bool, error)
|
||||
}
|
||||
|
||||
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
|
||||
type loggerAdapter struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
|
||||
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
|
||||
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
|
||||
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
|
||||
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
|
||||
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
|
||||
|
||||
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
|
||||
type cacheAdapter struct {
|
||||
cache *UniversalCache
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Get(key string) (any, bool) {
|
||||
return c.cache.Get(key)
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
|
||||
return c.cache.Set(key, value, ttl)
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Delete(key string) {
|
||||
c.cache.Delete(key)
|
||||
}
|
||||
|
||||
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
|
||||
type fileStoreWrapper struct {
|
||||
inner *dcrstorage.FileStore
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
innerCreds := convertCredsToInternal(creds)
|
||||
return w.inner.Save(ctx, providerURL, innerCreds)
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
innerCreds, err := w.inner.Load(ctx, providerURL)
|
||||
if err != nil || innerCreds == nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertCredsFromInternal(innerCreds), nil
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
|
||||
return w.inner.Delete(ctx, providerURL)
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
return w.inner.Exists(ctx, providerURL)
|
||||
}
|
||||
|
||||
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
|
||||
func (w *fileStoreWrapper) basePath() string {
|
||||
return w.inner.BasePath()
|
||||
}
|
||||
|
||||
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
|
||||
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
|
||||
return w.inner.GetFilePath(providerURL)
|
||||
}
|
||||
|
||||
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
|
||||
type redisStoreWrapper struct {
|
||||
inner *dcrstorage.RedisStore
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
innerCreds := convertCredsToInternal(creds)
|
||||
return w.inner.Save(ctx, providerURL, innerCreds)
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
innerCreds, err := w.inner.Load(ctx, providerURL)
|
||||
if err != nil || innerCreds == nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertCredsFromInternal(innerCreds), nil
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
|
||||
return w.inner.Delete(ctx, providerURL)
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
return w.inner.Exists(ctx, providerURL)
|
||||
}
|
||||
|
||||
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
|
||||
// This is the default storage backend for backward compatibility with existing deployments.
|
||||
type FileCredentialsStore = fileStoreWrapper
|
||||
|
||||
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
|
||||
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
|
||||
type RedisCredentialsStore = redisStoreWrapper
|
||||
|
||||
// NewFileCredentialsStore creates a new file-based credentials store.
|
||||
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
|
||||
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
|
||||
var dcrLogger dcrstorage.Logger
|
||||
if logger != nil {
|
||||
dcrLogger = &loggerAdapter{logger: logger}
|
||||
}
|
||||
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
|
||||
return &fileStoreWrapper{inner: inner}
|
||||
}
|
||||
|
||||
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
|
||||
// The cache should be configured with a Redis backend for distributed storage.
|
||||
// If keyPrefix is empty, defaults to "dcr:creds:"
|
||||
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
|
||||
var dcrLogger dcrstorage.Logger
|
||||
if logger != nil {
|
||||
dcrLogger = &loggerAdapter{logger: logger}
|
||||
}
|
||||
cacheAdapt := &cacheAdapter{cache: cache}
|
||||
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
|
||||
return &redisStoreWrapper{inner: inner}
|
||||
}
|
||||
|
||||
// Helper functions to convert between main package and internal package types
|
||||
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
|
||||
if creds == nil {
|
||||
return nil
|
||||
}
|
||||
return &dcrstorage.ClientRegistrationResponse{
|
||||
SubjectType: creds.SubjectType,
|
||||
LogoURI: creds.LogoURI,
|
||||
RegistrationAccessToken: creds.RegistrationAccessToken,
|
||||
RegistrationClientURI: creds.RegistrationClientURI,
|
||||
Scope: creds.Scope,
|
||||
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
|
||||
TOSURI: creds.TOSURI,
|
||||
PolicyURI: creds.PolicyURI,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
ApplicationType: creds.ApplicationType,
|
||||
ClientID: creds.ClientID,
|
||||
ClientName: creds.ClientName,
|
||||
JWKSURI: creds.JWKSURI,
|
||||
ClientURI: creds.ClientURI,
|
||||
Contacts: creds.Contacts,
|
||||
GrantTypes: creds.GrantTypes,
|
||||
ResponseTypes: creds.ResponseTypes,
|
||||
RedirectURIs: creds.RedirectURIs,
|
||||
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
|
||||
ClientIDIssuedAt: creds.ClientIDIssuedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
|
||||
if creds == nil {
|
||||
return nil
|
||||
}
|
||||
return &ClientRegistrationResponse{
|
||||
SubjectType: creds.SubjectType,
|
||||
LogoURI: creds.LogoURI,
|
||||
RegistrationAccessToken: creds.RegistrationAccessToken,
|
||||
RegistrationClientURI: creds.RegistrationClientURI,
|
||||
Scope: creds.Scope,
|
||||
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
|
||||
TOSURI: creds.TOSURI,
|
||||
PolicyURI: creds.PolicyURI,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
ApplicationType: creds.ApplicationType,
|
||||
ClientID: creds.ClientID,
|
||||
ClientName: creds.ClientName,
|
||||
JWKSURI: creds.JWKSURI,
|
||||
ClientURI: creds.ClientURI,
|
||||
Contacts: creds.Contacts,
|
||||
GrantTypes: creds.GrantTypes,
|
||||
ResponseTypes: creds.ResponseTypes,
|
||||
RedirectURIs: creds.RedirectURIs,
|
||||
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
|
||||
ClientIDIssuedAt: creds.ClientIDIssuedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
|
||||
// This factory function handles backend selection logic:
|
||||
// - "file": Use file-based storage (default for backward compatibility)
|
||||
// - "redis": Use Redis exclusively (fails if Redis unavailable)
|
||||
// - "auto": Use Redis if available, fallback to file
|
||||
func NewDCRCredentialsStore(
|
||||
config *DynamicClientRegistrationConfig,
|
||||
cacheManager *CacheManager,
|
||||
logger *Logger,
|
||||
) (DCRCredentialsStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("DCR config is nil")
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
backend := config.StorageBackend
|
||||
if backend == "" {
|
||||
backend = string(DCRStorageBackendAuto) // Default to auto selection
|
||||
}
|
||||
|
||||
switch DCRStorageBackend(backend) {
|
||||
case DCRStorageBackendFile:
|
||||
logger.Info("Using file-based storage for DCR credentials")
|
||||
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
|
||||
|
||||
case DCRStorageBackendRedis:
|
||||
cache := getDCRCache(cacheManager)
|
||||
if cache == nil {
|
||||
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
|
||||
}
|
||||
logger.Info("Using Redis storage for DCR credentials")
|
||||
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
|
||||
|
||||
case DCRStorageBackendAuto:
|
||||
// Try Redis first, fallback to file
|
||||
cache := getDCRCache(cacheManager)
|
||||
if cache != nil && cache.backend != nil {
|
||||
logger.Info("Auto-selected Redis storage for DCR credentials")
|
||||
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
|
||||
}
|
||||
logger.Info("Redis not available, using file storage for DCR credentials")
|
||||
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
|
||||
}
|
||||
}
|
||||
|
||||
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
|
||||
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
|
||||
if cacheManager == nil {
|
||||
return nil
|
||||
}
|
||||
cacheManager.mu.RLock()
|
||||
defer cacheManager.mu.RUnlock()
|
||||
|
||||
if cacheManager.manager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cacheManager.manager.GetDCRCredentialsCache()
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
|
||||
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temp directory for test files
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "test-access-token",
|
||||
RegistrationClientURI: "https://example.com/register/test-client-id",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
ResponseTypes: []string{"code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
// Save credentials
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
// Load credentials
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
|
||||
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
tempDir2 := t.TempDir()
|
||||
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
|
||||
|
||||
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent file: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
|
||||
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Expected credentials to not exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existent credentials", func(t *testing.T) {
|
||||
// Should not error
|
||||
err := store.Delete(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_MultiProvider tests multi-provider support
|
||||
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
provider1 := "https://auth1.example.com"
|
||||
provider2 := "https://auth2.example.com"
|
||||
|
||||
creds1 := &ClientRegistrationResponse{
|
||||
ClientID: "client-1",
|
||||
ClientSecret: "secret-1",
|
||||
}
|
||||
creds2 := &ClientRegistrationResponse{
|
||||
ClientID: "client-2",
|
||||
ClientSecret: "secret-2",
|
||||
}
|
||||
|
||||
// Save credentials for both providers
|
||||
if err := store.Save(ctx, provider1, creds1); err != nil {
|
||||
t.Fatalf("Failed to save creds1: %v", err)
|
||||
}
|
||||
if err := store.Save(ctx, provider2, creds2); err != nil {
|
||||
t.Fatalf("Failed to save creds2: %v", err)
|
||||
}
|
||||
|
||||
// Load and verify each provider's credentials
|
||||
loaded1, err := store.Load(ctx, provider1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds1: %v", err)
|
||||
}
|
||||
if loaded1.ClientID != "client-1" {
|
||||
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
|
||||
}
|
||||
|
||||
loaded2, err := store.Load(ctx, provider2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds2: %v", err)
|
||||
}
|
||||
if loaded2.ClientID != "client-2" {
|
||||
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
|
||||
}
|
||||
|
||||
// Delete one shouldn't affect the other
|
||||
if err := store.Delete(ctx, provider1); err != nil {
|
||||
t.Fatalf("Failed to delete creds1: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, provider2)
|
||||
if !exists {
|
||||
t.Error("Provider 2 credentials should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
|
||||
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
|
||||
// Concurrent saves
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = store.Save(ctx, providerURL, creds)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Concurrent loads
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = store.Load(ctx, providerURL)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Final verification
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after concurrent access: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test-client" {
|
||||
t.Error("Credentials corrupted after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_InvalidInput tests error handling
|
||||
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty provider URL uses default path", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
err := store.Save(ctx, "", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with empty provider URL failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with empty provider URL failed: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials with empty provider URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_DefaultPath tests default path behavior
|
||||
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore("", logger)
|
||||
|
||||
// Just verify we can create with empty path and it has a default
|
||||
if store.basePath() == "" {
|
||||
t.Error("Expected default base path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
|
||||
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create an in-memory cache for testing
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "redis-test-client",
|
||||
ClientSecret: "redis-test-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "redis-test-token",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
}
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
|
||||
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("expired credentials should fail", func(t *testing.T) {
|
||||
expiredCreds := &ClientRegistrationResponse{
|
||||
ClientID: "expired-client",
|
||||
ClientSecret: "expired-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "no-expiry-client",
|
||||
ClientSecret: "no-expiry-secret",
|
||||
ClientSecretExpiresAt: 0, // No expiry
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://noexpiry.example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials without expiry: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_InvalidInput tests error handling
|
||||
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDCRStorageFactory tests the factory function
|
||||
func TestDCRStorageFactory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("nil config returns error", func(t *testing.T) {
|
||||
_, err := NewDCRCredentialsStore(nil, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("file backend creates file store", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "file",
|
||||
CredentialsFile: "/tmp/test-creds.json",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file store: %v", err)
|
||||
}
|
||||
if store == nil {
|
||||
t.Error("Expected store but got nil")
|
||||
}
|
||||
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "redis",
|
||||
}
|
||||
|
||||
_, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for redis backend without cache manager")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "auto",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auto store: %v", err)
|
||||
}
|
||||
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore for auto without redis")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown backend returns error", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "unknown",
|
||||
}
|
||||
|
||||
_, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for unknown backend")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty backend defaults to auto", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create store with empty backend: %v", err)
|
||||
}
|
||||
|
||||
// Should default to file (auto without redis)
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore for empty backend")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDynamicClientRegistrar_WithStore tests registrar with store
|
||||
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrarWithStore(
|
||||
nil, // httpClient
|
||||
logger,
|
||||
config,
|
||||
"https://auth.example.com",
|
||||
store,
|
||||
)
|
||||
|
||||
if registrar == nil {
|
||||
t.Fatal("Expected registrar but got nil")
|
||||
}
|
||||
|
||||
if registrar.store == nil {
|
||||
t.Error("Expected store to be set")
|
||||
}
|
||||
|
||||
// Test SetStore
|
||||
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
|
||||
registrar.SetStore(newStore)
|
||||
|
||||
if registrar.store != newStore {
|
||||
t.Error("SetStore did not update the store")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
|
||||
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
providerURL := "https://auth.example.com"
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-save credentials
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "pre-saved-client",
|
||||
ClientSecret: "pre-saved-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
}
|
||||
if err := store.Save(ctx, providerURL, testCreds); err != nil {
|
||||
t.Fatalf("Failed to pre-save credentials: %v", err)
|
||||
}
|
||||
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrarWithStore(
|
||||
nil,
|
||||
logger,
|
||||
config,
|
||||
providerURL,
|
||||
store,
|
||||
)
|
||||
|
||||
// Test loading via the internal method
|
||||
loaded, err := registrar.loadCredentialsFromStore(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load from store: %v", err)
|
||||
}
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != "pre-saved-client" {
|
||||
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
|
||||
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
// Write corrupted JSON
|
||||
filePath := store.getFilePath(providerURL)
|
||||
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
|
||||
t.Fatalf("Failed to write corrupted file: %v", err)
|
||||
}
|
||||
|
||||
// Should return error for corrupted file
|
||||
_, err := store.Load(ctx, providerURL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for corrupted JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
|
||||
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(deepPath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
|
||||
err := store.Save(ctx, "https://example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save with nested directory: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after nested directory creation: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials from nested directory")
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,10 @@ The **audience** (`aud`) claim in a JWT identifies the intended recipient of the
|
||||
|
||||
### Why Does This Matter?
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
Audience validation rejects access tokens whose `aud` claim does not match the
|
||||
expected audience, blocking the trivial form of token confusion where a token
|
||||
issued for API A is presented to API B. (Defence in depth — pair with
|
||||
short-lived tokens, rotation, and per-API client credentials.)
|
||||
|
||||
---
|
||||
|
||||
@@ -137,8 +140,8 @@ http:
|
||||
**Recommended:** `true` for production
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
|
||||
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
|
||||
- When `true`: On audience mismatch, the middleware does **not** silently fall back to ID-token validation. It tries to refresh the access token first; if no refresh token is present (or refresh fails), the user is re-authenticated.
|
||||
- When `false`: Logs warnings and falls back to ID-token validation (backward compatible).
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
@@ -349,7 +352,7 @@ When opaque tokens are detected:
|
||||
|
||||
**Cache behavior:**
|
||||
- Cache key: Token hash
|
||||
- TTL: 5 minutes or token expiry (whichever is shorter)
|
||||
- TTL: 5 minutes; if the token's `exp` is sooner, the cache entry expires at `exp` instead. Tokens without `exp` use the flat 5-minute TTL.
|
||||
- Reduces introspection requests for frequently used tokens
|
||||
|
||||
---
|
||||
|
||||
+68
-7
@@ -52,7 +52,7 @@ spec:
|
||||
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
|
||||
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
|
||||
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
@@ -62,13 +62,40 @@ spec:
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
|
||||
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
|
||||
the correct default behind any TLS-terminating load balancer (AWS ALB, Google
|
||||
Cloud LB, Azure App Gateway) — `X-Forwarded-Proto` cannot be trusted (ALB may
|
||||
overwrite it).
|
||||
|
||||
```yaml
|
||||
forceHTTPS: true # Required for correct redirect URIs
|
||||
```
|
||||
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
|
||||
dev). Otherwise leave it at default.
|
||||
|
||||
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
|
||||
### Streaming Endpoints (SSE and WebSocket)
|
||||
|
||||
The middleware automatically bypasses the OIDC redirect for two request kinds
|
||||
that browsers cannot follow a 302 on:
|
||||
|
||||
| Bypass | Triggered by |
|
||||
|--------|--------------|
|
||||
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
|
||||
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
|
||||
|
||||
These requests do **not** require any explicit configuration — they are
|
||||
handled implicitly. However, the bypass is **not** unauthenticated:
|
||||
|
||||
- A valid, encrypted session cookie is required. Requests without one are
|
||||
rejected (the connection cannot proceed to the backend).
|
||||
- The session cookie is sealed with `sessionEncryptionKey`, so the
|
||||
`authenticated` flag cannot be forged.
|
||||
- Validation is cookie-only — no JWK fetch / signature verification — so
|
||||
streaming endpoints keep working when the OIDC provider is briefly
|
||||
unavailable.
|
||||
- The user identifier from the session is forwarded as `X-Forwarded-User`
|
||||
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
|
||||
|
||||
For browser clients, the user must complete the normal OIDC flow on a
|
||||
regular HTTP page first; the resulting session cookie is then reused on the
|
||||
SSE / WebSocket connection.
|
||||
|
||||
---
|
||||
|
||||
@@ -113,6 +140,7 @@ strictAudienceValidation: true
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
@@ -384,10 +412,14 @@ scopes:
|
||||
|
||||
### Dynamic Client Registration (RFC 7591)
|
||||
|
||||
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
|
||||
|
||||
**Basic Configuration (Single Instance):**
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
initialAccessToken: "your-token" # Optional
|
||||
initialAccessToken: "your-token" # Optional, if provider requires it
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-credentials.json"
|
||||
clientMetadata:
|
||||
@@ -400,6 +432,35 @@ dynamicClientRegistration:
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
**Multi-Replica Deployment (Kubernetes):**
|
||||
|
||||
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: "redis" # Share credentials via Redis
|
||||
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "redis"
|
||||
```
|
||||
|
||||
**Storage Backend Options:**
|
||||
|
||||
| Backend | Description | Use Case |
|
||||
|---------|-------------|----------|
|
||||
| `file` | Store credentials in local file | Single instance deployments |
|
||||
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
|
||||
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
|
||||
|
||||
### Multi-Replica Deployment
|
||||
|
||||
Without Redis, disable replay detection:
|
||||
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
# Dynamic Client Registration (RFC 7591)
|
||||
|
||||
The middleware can register itself with an OIDC provider at startup instead of
|
||||
using a pre-provisioned `clientID` / `clientSecret`. Useful for multi-tenant
|
||||
deployments, self-service integrations, and ephemeral environments.
|
||||
|
||||
## How it works
|
||||
|
||||
1. Middleware reads `registration_endpoint` from `.well-known/openid-configuration`.
|
||||
2. If `clientID` is empty, it `POST`s `clientMetadata` to the registration endpoint.
|
||||
3. Returned `client_id` / `client_secret` are cached, optionally persisted.
|
||||
4. Subsequent requests use the registered credentials.
|
||||
|
||||
For multi-replica deployments, set `storageBackend: redis` so all replicas
|
||||
share one client and avoid registration races.
|
||||
|
||||
## Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-dcr
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-oidc-provider.com
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: redis # file | redis | auto
|
||||
initialAccessToken: "" # optional, for protected endpoints
|
||||
registrationEndpoint: "" # optional, override discovery
|
||||
credentialsFile: /tmp/oidc-client-credentials.json
|
||||
redisKeyPrefix: "dcr:creds:"
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- https://app.example.com/oauth2/callback
|
||||
client_name: My Application
|
||||
application_type: web
|
||||
grant_types: [authorization_code, refresh_token]
|
||||
response_types: [code]
|
||||
token_endpoint_auth_method: client_secret_basic
|
||||
contacts: [admin@example.com]
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `enabled` | `false` | Enable DCR. |
|
||||
| `persistCredentials` | `false` | Save returned credentials for reuse across restarts. |
|
||||
| `storageBackend` | `auto` | `file`, `redis`, or `auto` (Redis if available, else file). |
|
||||
| `credentialsFile` | `/tmp/oidc-client-credentials.json` | Path for file-backed storage. Mode `0600`. |
|
||||
| `redisKeyPrefix` | (none — set explicitly) | Key prefix for Redis-backed storage. The code does not inject a default; if unset, keys have no prefix. `dcr:creds:` is a sensible convention. |
|
||||
| `registrationEndpoint` | discovered | Override the discovered endpoint. |
|
||||
| `initialAccessToken` | none | Bearer token for protected registration endpoints. |
|
||||
| `clientMetadata.redirect_uris` | required | Callback URIs for the OAuth flow. |
|
||||
| `clientMetadata.client_name` | none | Human-readable client name. |
|
||||
| `clientMetadata.application_type` | `web` | `web` or `native`. |
|
||||
| `clientMetadata.grant_types` | `[authorization_code, refresh_token]` | OAuth grant types. |
|
||||
| `clientMetadata.response_types` | `[code]` | OAuth response types. |
|
||||
| `clientMetadata.token_endpoint_auth_method` | `client_secret_basic` | `client_secret_basic`, `client_secret_post`, or `none`. |
|
||||
| `clientMetadata.scope` | none | Space-separated scopes. |
|
||||
| `clientMetadata.contacts` | none | Admin email addresses. |
|
||||
| `clientMetadata.logo_uri` | none | Logo URL for consent screens. |
|
||||
| `clientMetadata.client_uri` | none | Client homepage URL. |
|
||||
| `clientMetadata.policy_uri` | none | Privacy policy URL. |
|
||||
| `clientMetadata.tos_uri` | none | Terms of service URL. |
|
||||
|
||||
## Provider support
|
||||
|
||||
The middleware does not gate DCR by provider — if the provider exposes a
|
||||
`registration_endpoint` in its discovery document (or you set
|
||||
`registrationEndpoint` explicitly), DCR will attempt registration. The table
|
||||
below is informational guidance based on each provider's published support.
|
||||
|
||||
| Provider | DCR | Notes |
|
||||
|----------|-----|-------|
|
||||
| Keycloak | Yes | Enable in realm settings. |
|
||||
| Auth0 | Yes | Requires Management API token. |
|
||||
| Okta | Yes | Enable Dynamic Client Registration in admin console. |
|
||||
| Azure AD | Limited | Use App Registration API instead. |
|
||||
| Google | No | Manual registration required. |
|
||||
| AWS Cognito | No | Manual registration required. |
|
||||
|
||||
## Security notes
|
||||
|
||||
- Registration endpoints must be HTTPS (loopback excepted for local dev).
|
||||
- Use `initialAccessToken` in production to gate registration.
|
||||
- File-backed credentials use `0600`; protect the mount path.
|
||||
- The plugin marks credentials invalid when within ~5 min of `client_secret_expires_at` but does **not** automatically re-register. If your provider sets a non-zero expiry, schedule manual rotation (delete the credentials file or Redis entry, restart) before that time.
|
||||
+20
-99
@@ -16,9 +16,8 @@ Guide for local development, testing, and contributing to the Traefik OIDC middl
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Go 1.23+** for plugin compilation
|
||||
- **Docker & Docker Compose** for local testing
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.)
|
||||
- **Go 1.24+** (matches `go.mod`; CI runs Go 1.24.11)
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.) for any end-to-end test against a real provider
|
||||
|
||||
### Required Development Tools
|
||||
|
||||
@@ -40,110 +39,32 @@ go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
## Local Development Setup
|
||||
|
||||
### Docker Compose Environment
|
||||
|
||||
The repository includes a Docker Compose setup for testing the plugin locally.
|
||||
|
||||
#### 1. Host Configuration
|
||||
|
||||
Add to `/etc/hosts`:
|
||||
### Build and unit tests
|
||||
|
||||
```bash
|
||||
127.0.0.1 hello.localhost
|
||||
127.0.0.1 traefik.localhost
|
||||
go mod tidy
|
||||
go build ./...
|
||||
go test ./... -short # fast loop, < 30 s
|
||||
go test -race -timeout=15m ./...
|
||||
```
|
||||
|
||||
#### 2. Plugin Configuration
|
||||
### Sample plugin configurations
|
||||
|
||||
The plugin is loaded using Traefik's **local plugins mode**:
|
||||
Working middleware/Traefik configs live in [`examples/`](../examples/):
|
||||
|
||||
- Plugin source: Parent directory (`../`)
|
||||
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
|
||||
- Configuration: `experimental.localPlugins` in `traefik.yml`
|
||||
- `complete-traefik-config.yaml` — full middleware example
|
||||
- `redis-config.yaml` — Redis cache configuration
|
||||
|
||||
#### 3. OIDC Provider Setup
|
||||
To run the plugin against a real Traefik instance, drop the project on disk
|
||||
and load it via `experimental.localPlugins` in your Traefik static config —
|
||||
see the [README install section](../README.md#install).
|
||||
|
||||
Edit `docker/dynamic.yml` with your provider details:
|
||||
### Integration tests
|
||||
|
||||
**Google:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
logoutURL: "/oauth2/logout"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
**Azure AD:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
#### 4. Start Environment
|
||||
Integration tests live in `integration/`. Run them explicitly:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 5. Test Plugin
|
||||
|
||||
- **Protected App**: http://hello.localhost (redirects to OIDC)
|
||||
- **Traefik Dashboard**: http://traefik.localhost:8080
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. **Edit plugin code** in the project root
|
||||
2. **Build and test** (optional syntax check):
|
||||
```bash
|
||||
go mod tidy
|
||||
go build .
|
||||
go test ./...
|
||||
```
|
||||
3. **Restart Traefik** to reload plugin:
|
||||
```bash
|
||||
docker-compose restart traefik
|
||||
```
|
||||
4. **Test changes** at http://hello.localhost
|
||||
|
||||
### Debugging
|
||||
|
||||
**View plugin logs:**
|
||||
```bash
|
||||
docker-compose logs -f traefik | grep traefikoidc
|
||||
```
|
||||
|
||||
**Check plugin loading:**
|
||||
```bash
|
||||
docker-compose logs traefik | grep -i plugin
|
||||
```
|
||||
|
||||
**Verify plugin directory:**
|
||||
```bash
|
||||
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
|
||||
go test ./integration/... -run Integration -v
|
||||
```
|
||||
|
||||
---
|
||||
@@ -299,7 +220,7 @@ The repository uses GitHub Actions for comprehensive validation with 20+ paralle
|
||||
|
||||
#### Testing (9 suites)
|
||||
- Race Detector
|
||||
- Coverage (75% threshold)
|
||||
- Coverage (70% threshold, enforced in `pr.yaml`)
|
||||
- Memory Leaks
|
||||
- Integration Tests
|
||||
- Regression Tests
|
||||
@@ -323,13 +244,13 @@ Tests run in parallel for:
|
||||
#### Performance & Build (3 checks)
|
||||
- Benchmarks
|
||||
- Multi-platform Build (linux/darwin x amd64/arm64)
|
||||
- Go Version Compatibility (Go 1.23 & 1.24)
|
||||
- Go Version Compatibility (currently Go 1.24.11 in CI)
|
||||
|
||||
### Quality Gates
|
||||
|
||||
All PRs must pass:
|
||||
- All parallel checks
|
||||
- 75% test coverage minimum
|
||||
- 70% test coverage minimum
|
||||
- Zero security vulnerabilities
|
||||
- No race conditions
|
||||
- No memory leaks
|
||||
|
||||
+5
-3
@@ -23,10 +23,10 @@ Configuration reference for each supported OIDC provider.
|
||||
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|
||||
|----------|-------------|----------------|----------------|-----------|
|
||||
| Google | Full | Yes | `accounts.google.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com`, `sts.windows.net` | Yes |
|
||||
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com` | Yes |
|
||||
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com`, `*.oktapreview.com`, `*.okta-emea.com` | Yes |
|
||||
| Keycloak | Full | Yes | host containing `keycloak`, or `/realms/` in path (matches both `/auth/realms/` legacy and `/realms/` modern) | Yes |
|
||||
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
|
||||
| GitLab | Full | Yes | `gitlab.com` | Yes |
|
||||
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
|
||||
@@ -353,6 +353,8 @@ allowPrivateIPAddresses: true # Required for private IPs
|
||||
- Roles: User Client Role mapper with "Add to ID token" enabled
|
||||
- Groups: Group Membership mapper with "Add to ID token" enabled
|
||||
|
||||
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
+14
-6
@@ -109,11 +109,11 @@ redis:
|
||||
| `writeTimeout` | int | `3` | Write timeout (seconds) |
|
||||
| `enableTLS` | bool | `false` | Enable TLS for connections |
|
||||
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
|
||||
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
|
||||
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
|
||||
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
|
||||
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
|
||||
| `enableCircuitBreaker` | bool | `false` | Wrap the Redis backend with a circuit breaker. **Recommended `true` in production.** |
|
||||
| `circuitBreakerThreshold` | int | `5` | Consecutive failures before the circuit opens (only when `enableCircuitBreaker: true`). |
|
||||
| `circuitBreakerTimeout` | int | `60` | Seconds the circuit stays open before allowing a probe (only when `enableCircuitBreaker: true`). |
|
||||
| `enableHealthCheck` | bool | `false` | Wrap the Redis backend with periodic health checks. **Recommended `true` in production.** |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval in seconds (only when `enableHealthCheck: true`). |
|
||||
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
|
||||
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
|
||||
|
||||
@@ -134,13 +134,21 @@ REDIS_READ_TIMEOUT=3
|
||||
REDIS_WRITE_TIMEOUT=3
|
||||
REDIS_ENABLE_TLS=false
|
||||
REDIS_TLS_SKIP_VERIFY=false
|
||||
REDIS_HYBRID_L1_SIZE=500
|
||||
REDIS_HYBRID_L1_MEMORY_MB=10
|
||||
```
|
||||
|
||||
> Resilience fields (`enableCircuitBreaker`, `enableHealthCheck`,
|
||||
> `circuitBreakerThreshold`, `circuitBreakerTimeout`, `healthCheckInterval`)
|
||||
> have no environment variable fallback — set them in plugin configuration.
|
||||
|
||||
Invalid `cacheMode` values are rejected at plugin startup.
|
||||
|
||||
---
|
||||
|
||||
## Cache Modes
|
||||
|
||||
### Memory Mode (Default without Redis)
|
||||
### Memory Mode (used when Redis is disabled)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
|
||||
+2
-2
@@ -6,8 +6,8 @@ Comprehensive testing infrastructure for traefikoidc.
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Test files | 99 |
|
||||
| Lines of test code | ~65,500 |
|
||||
| Test files | 110 |
|
||||
| Lines of test code | ~72,000 |
|
||||
| Code coverage | 71.0% |
|
||||
| Race conditions | None (all pass with `-race`) |
|
||||
|
||||
|
||||
+121
-2
@@ -90,6 +90,7 @@
|
||||
<a href="#configuration" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Configuration</a>
|
||||
<a href="#deployment" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Deployment</a>
|
||||
<a href="#security" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Security</a>
|
||||
<a href="#logout" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Logout</a>
|
||||
</div>
|
||||
<div class="flex items-center space-x-4">
|
||||
<button id="theme-toggle" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 p-2 min-w-[44px] min-h-[44px] flex items-center justify-center" aria-label="Toggle theme">
|
||||
@@ -114,6 +115,7 @@
|
||||
<a href="#configuration" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Configuration</a>
|
||||
<a href="#deployment" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Deployment</a>
|
||||
<a href="#security" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Security</a>
|
||||
<a href="#logout" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Logout</a>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
@@ -193,7 +195,7 @@
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-1">Dynamic Registration</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration</p>
|
||||
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration with Redis storage support for multi-replica deployments</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -716,6 +718,11 @@ spec:
|
||||
<td class="py-2 px-3">86400</td>
|
||||
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
|
||||
<td class="py-2 px-3">21600</td>
|
||||
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
|
||||
<td class="py-2 px-3">_oidc_raczylo_</td>
|
||||
@@ -856,7 +863,54 @@ spec:
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Dynamic Client Registration (RFC 7591)</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Automatically register your application with the OIDC provider. Supports Redis storage for multi-replica deployments:</p>
|
||||
<div class="overflow-x-auto mb-4">
|
||||
<table class="w-full text-sm">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200 dark:border-gray-700">
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Parameter</th>
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Default</th>
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Description</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="text-gray-600 dark:text-gray-400">
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.enabled</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable dynamic client registration</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.persistCredentials</code></td>
|
||||
<td class="py-2 px-3">true</td>
|
||||
<td class="py-2 px-3">Persist registered credentials across restarts</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.storageBackend</code></td>
|
||||
<td class="py-2 px-3">auto</td>
|
||||
<td class="py-2 px-3">Storage backend: <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">file</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis</code>, or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">auto</code> (uses Redis if available)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.redisKeyPrefix</code></td>
|
||||
<td class="py-2 px-3">dcr:creds:</td>
|
||||
<td class="py-2 px-3">Redis key prefix for DCR credentials</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.clientMetadata.redirect_uris</code></td>
|
||||
<td class="py-2 px-3">-</td>
|
||||
<td class="py-2 px-3">Redirect URIs for the registered client (required)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
@@ -1177,6 +1231,71 @@ spec:
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- IdP-Initiated Logout Section -->
|
||||
<section id="logout" class="py-12 sm:py-16 md:py-20 bg-white dark:bg-gray-900 theme-transition">
|
||||
<div class="max-w-6xl mx-auto px-4 sm:px-6">
|
||||
<div class="text-center mb-8 sm:mb-12">
|
||||
<h2 class="text-2xl sm:text-3xl md:text-4xl font-bold text-gray-900 dark:text-gray-100 mb-3 sm:mb-4">IdP-Initiated Logout</h2>
|
||||
<p class="text-base sm:text-lg text-gray-600 dark:text-gray-300 px-4">Support for OIDC Back-Channel and Front-Channel Logout specifications</p>
|
||||
</div>
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="grid md:grid-cols-2 gap-6 mb-8">
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
|
||||
<i class="fas fa-server mr-2 text-blue-500"></i>
|
||||
Back-Channel Logout
|
||||
</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
|
||||
Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out.
|
||||
</p>
|
||||
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
|
||||
<li>• Signed JWT logout tokens</li>
|
||||
<li>• Session ID (sid) based invalidation</li>
|
||||
<li>• Subject (sub) based invalidation</li>
|
||||
<li>• Works behind firewalls</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
|
||||
<i class="fas fa-browser mr-2 text-purple-500"></i>
|
||||
Front-Channel Logout
|
||||
</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
|
||||
Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout.
|
||||
</p>
|
||||
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
|
||||
<li>• Iframe-based session termination</li>
|
||||
<li>• Immediate cookie invalidation</li>
|
||||
<li>• Simple GET request handling</li>
|
||||
<li>• Issuer validation</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Configuration Example</h3>
|
||||
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# ... other OIDC configuration ...
|
||||
|
||||
# Back-Channel Logout (server-to-server)
|
||||
enableBackchannelLogout: true
|
||||
backchannelLogoutURL: "/backchannel-logout"
|
||||
|
||||
# Front-Channel Logout (browser-based)
|
||||
enableFrontchannelLogout: true
|
||||
frontchannelLogoutURL: "/frontchannel-logout"</code></pre>
|
||||
<p class="text-gray-600 dark:text-gray-400 text-sm mt-4">
|
||||
Configure your IdP with the full URLs (e.g., <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">https://your-app.example.com/backchannel-logout</code>).
|
||||
When a user logs out from the IdP, all their sessions across your applications will be invalidated.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Why Choose Section -->
|
||||
<section class="py-12 sm:py-16 md:py-20 bg-gray-50 dark:bg-gray-800 theme-transition">
|
||||
<div class="max-w-6xl mx-auto px-4 sm:px-6">
|
||||
|
||||
@@ -50,6 +50,7 @@ type DynamicClientRegistrar struct {
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
store DCRCredentialsStore // Storage backend for credentials
|
||||
providerURL string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -73,8 +74,37 @@ func NewDynamicClientRegistrar(
|
||||
}
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
|
||||
func NewDynamicClientRegistrarWithStore(
|
||||
httpClient *http.Client,
|
||||
logger *Logger,
|
||||
dcrConfig *DynamicClientRegistrationConfig,
|
||||
providerURL string,
|
||||
store DCRCredentialsStore,
|
||||
) *DynamicClientRegistrar {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
return &DynamicClientRegistrar{
|
||||
httpClient: httpClient,
|
||||
logger: logger,
|
||||
config: dcrConfig,
|
||||
providerURL: providerURL,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// SetStore sets the credentials store for the registrar
|
||||
// This allows setting the store after creation when the cache manager is available
|
||||
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.store = store
|
||||
}
|
||||
|
||||
// RegisterClient performs dynamic client registration with the OIDC provider
|
||||
// It first attempts to load existing credentials from a file if persistence is enabled,
|
||||
// It first attempts to load existing credentials from storage if persistence is enabled,
|
||||
// then registers a new client if no valid credentials exist.
|
||||
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
|
||||
if r.config == nil || !r.config.Enabled {
|
||||
@@ -83,10 +113,13 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
|
||||
|
||||
// Try to load existing credentials if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if resp, err := r.loadCredentials(); err == nil && resp != nil {
|
||||
resp, err := r.loadCredentialsFromStore(ctx)
|
||||
if err != nil {
|
||||
r.logger.Debugf("Failed to load credentials from store: %v", err)
|
||||
} else if resp != nil {
|
||||
// Check if credentials are still valid (not expired)
|
||||
if r.areCredentialsValid(resp) {
|
||||
r.logger.Info("Loaded existing client credentials from file")
|
||||
r.logger.Info("Loaded existing client credentials from storage")
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = resp
|
||||
r.mu.Unlock()
|
||||
@@ -179,7 +212,7 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
|
||||
|
||||
// Persist credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist client credentials: %v", err)
|
||||
// Don't fail registration if persistence fails
|
||||
}
|
||||
@@ -315,7 +348,44 @@ func (r *DynamicClientRegistrar) credentialsFilePath() string {
|
||||
return "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file
|
||||
// loadCredentialsFromStore loads client credentials from the configured storage backend
|
||||
// Falls back to legacy file-based loading if no store is configured
|
||||
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Load(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based loading
|
||||
return r.loadCredentials()
|
||||
}
|
||||
|
||||
// saveCredentialsToStore persists client credentials to the configured storage backend
|
||||
// Falls back to legacy file-based saving if no store is configured
|
||||
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Save(ctx, r.providerURL, resp)
|
||||
}
|
||||
// Fallback to legacy file-based saving
|
||||
return r.saveCredentials(resp)
|
||||
}
|
||||
|
||||
// deleteCredentialsFromStore removes credentials from the configured storage backend
|
||||
// Falls back to legacy file-based deletion if no store is configured
|
||||
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Delete(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based deletion
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
@@ -333,7 +403,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials loads client credentials from a file
|
||||
// loadCredentials loads client credentials from a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
@@ -420,7 +490,7 @@ func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -527,11 +597,10 @@ func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) e
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials file if persistence is enabled
|
||||
// Remove credentials from storage if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
r.logger.Errorf("Failed to remove credentials file: %v", err)
|
||||
if err := r.deleteCredentialsFromStore(ctx); err != nil {
|
||||
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -40,6 +42,31 @@ func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, http
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if jwks == nil {
|
||||
return nil, fmt.Errorf("JWKS is nil")
|
||||
}
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Cleanup() {
|
||||
atomic.AddInt32(&m.CleanupCalls, 1)
|
||||
m.mu.Lock()
|
||||
|
||||
+1
-1
@@ -954,7 +954,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
|
||||
var degraded []string
|
||||
degraded := make([]string, 0, len(gd.degradedServices))
|
||||
for serviceName := range gd.degradedServices {
|
||||
degraded = append(degraded, serviceName)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
|
||||
@@ -12,8 +12,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
|
||||
+16
@@ -17,6 +17,21 @@ import (
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// newUUIDv4 returns an RFC 4122 v4 UUID string (e.g.
|
||||
// "f47ac10b-58cc-4372-a567-0e02b2c3d479") backed by crypto/rand. Used for CSRF
|
||||
// tokens and other opaque random identifiers — replaces github.com/google/uuid
|
||||
// to keep the plugin stdlib-only on the production path.
|
||||
func newUUIDv4() (string, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", fmt.Errorf("could not generate UUID: %w", err)
|
||||
}
|
||||
b[6] = (b[6] & 0x0f) | 0x40 // version 4
|
||||
b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
|
||||
}
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
|
||||
// The nonce is used to prevent replay attacks and associate client sessions with ID tokens.
|
||||
// Returns:
|
||||
@@ -336,6 +351,7 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
// and redirects to the provider's logout endpoint or configured post-logout URI.
|
||||
// It handles potential errors during session retrieval or clearing.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Processing logout request")
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewUUIDv4 verifies the in-house UUID v4 generator produces RFC 4122
|
||||
// compliant identifiers. Locks in the replacement for github.com/google/uuid
|
||||
// — a regression here would weaken the CSRF token used in the OIDC flow.
|
||||
func TestNewUUIDv4(t *testing.T) {
|
||||
rfc4122v4 := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
|
||||
|
||||
const samples = 1000
|
||||
seen := make(map[string]struct{}, samples)
|
||||
for i := 0; i < samples; i++ {
|
||||
got, err := newUUIDv4()
|
||||
if err != nil {
|
||||
t.Fatalf("newUUIDv4 failed: %v", err)
|
||||
}
|
||||
if !rfc4122v4.MatchString(got) {
|
||||
t.Fatalf("UUID %q does not match RFC 4122 v4 format", got)
|
||||
}
|
||||
if _, dup := seen[got]; dup {
|
||||
t.Fatalf("duplicate UUID emitted within %d samples: %q", samples, got)
|
||||
}
|
||||
seen[got] = struct{}{}
|
||||
}
|
||||
}
|
||||
+13
-5
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -25,10 +26,16 @@ type HTTPClientConfig struct {
|
||||
Timeout time.Duration
|
||||
MaxConnsPerHost int
|
||||
WriteBufferSize int
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
// RootCAs is an optional certificate pool used for TLS verification.
|
||||
// A nil pool means "use the system trust store" (default behavior).
|
||||
RootCAs *x509.CertPool
|
||||
// InsecureSkipVerify disables TLS certificate verification.
|
||||
// ONLY set this for local development against self-signed certificates.
|
||||
InsecureSkipVerify bool
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig returns the default configuration for general use
|
||||
@@ -203,7 +210,8 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: false, // Always verify certificates
|
||||
RootCAs: config.RootCAs,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
|
||||
},
|
||||
ForceAttemptHTTP2: config.ForceHTTP2,
|
||||
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
|
||||
|
||||
+18
-3
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -103,7 +104,8 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: false,
|
||||
RootCAs: config.RootCAs,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
|
||||
},
|
||||
ForceAttemptHTTP2: config.ForceHTTP2,
|
||||
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
|
||||
@@ -205,8 +207,21 @@ func (p *SharedTransportPool) performCleanup() {
|
||||
|
||||
// configKey generates a unique key for a config
|
||||
func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
|
||||
// Simple key based on main parameters
|
||||
return string(rune(config.MaxConnsPerHost)) + string(rune(config.MaxIdleConnsPerHost))
|
||||
// Pool transports by the parameters that change TLS or connection
|
||||
// behavior. RootCAs and InsecureSkipVerify MUST be part of the key:
|
||||
// otherwise a middleware configured with a custom CA would share a
|
||||
// transport with one using the system store, silently bypassing its
|
||||
// CA configuration.
|
||||
skip := "0"
|
||||
if config.InsecureSkipVerify {
|
||||
skip = "1"
|
||||
}
|
||||
return fmt.Sprintf("%d|%d|%p|%s",
|
||||
config.MaxConnsPerHost,
|
||||
config.MaxIdleConnsPerHost,
|
||||
config.RootCAs,
|
||||
skip,
|
||||
)
|
||||
}
|
||||
|
||||
// Cleanup closes all transports and stops the cleanup goroutine
|
||||
|
||||
+14
-27
@@ -10,6 +10,14 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Pre-compiled regex patterns for validation (const patterns should use MustCompile)
|
||||
var (
|
||||
emailRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
urlRegexPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
|
||||
tokenRegexPattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
|
||||
usernameRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
|
||||
)
|
||||
|
||||
// InputValidator provides comprehensive input validation and sanitization
|
||||
// to protect against common security vulnerabilities including SQL injection,
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
@@ -73,7 +81,7 @@ func DefaultInputValidationConfig() InputValidationConfig {
|
||||
}
|
||||
|
||||
// NewInputValidator creates a new input validator with the specified configuration.
|
||||
// It compiles all necessary regex patterns and initializes security pattern lists.
|
||||
// It uses pre-compiled regex patterns and initializes security pattern lists.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Validation configuration with size limits and mode settings.
|
||||
@@ -81,29 +89,8 @@ func DefaultInputValidationConfig() InputValidationConfig {
|
||||
//
|
||||
// Returns:
|
||||
// - A configured InputValidator instance.
|
||||
// - An error if regex compilation fails.
|
||||
// - An error (always nil, kept for API compatibility).
|
||||
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
|
||||
// Compile regex patterns
|
||||
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile email regex: %w", err)
|
||||
}
|
||||
|
||||
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
|
||||
}
|
||||
|
||||
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile token regex: %w", err)
|
||||
}
|
||||
|
||||
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile username regex: %w", err)
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
@@ -112,10 +99,10 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
emailRegex: emailRegexPattern,
|
||||
urlRegex: urlRegexPattern,
|
||||
tokenRegex: tokenRegexPattern,
|
||||
usernameRegex: usernameRegexPattern,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
|
||||
Vendored
+3
@@ -24,6 +24,7 @@ type Config struct {
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
TLSServerName string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
@@ -34,6 +35,8 @@ type Config struct {
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
EnableTLS bool
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+73
-26
@@ -20,6 +20,7 @@ type HybridBackend struct {
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
l1BackfillBuffer chan *l1BackfillItem
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
@@ -28,6 +29,7 @@ type HybridBackend struct {
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
l1BackfillDrops atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// l1BackfillItem represents a deferred write of an L2-resolved value back into
|
||||
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
|
||||
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
|
||||
type l1BackfillItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
|
||||
// L2 hits regardless of request rate.
|
||||
h.wg.Add(1)
|
||||
go h.l1BackfillWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
// Populate L1 cache with value from L2 (write-through on read).
|
||||
// Hand off to the bounded backfill worker instead of spawning a goroutine
|
||||
// per read - under burst that would mint thousands of goroutines.
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
close(h.l1BackfillBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
|
||||
// Drops on full buffer to keep the read path constant-time; the next L2 hit
|
||||
// for the same key simply re-queues it.
|
||||
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
|
||||
select {
|
||||
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
|
||||
default:
|
||||
h.l1BackfillDrops.Add(1)
|
||||
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// l1BackfillWorker drains the backfill queue serially. Single worker is
|
||||
// intentional - L1 writes are local and cheap, and serializing them keeps
|
||||
// goroutine count bounded under any read rate.
|
||||
func (h *HybridBackend) l1BackfillWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items best-effort then exit.
|
||||
for len(h.l1BackfillBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.l1BackfillBuffer:
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.l1BackfillBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
|
||||
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
|
||||
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
|
||||
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 256,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
const burst = 1000
|
||||
|
||||
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
|
||||
// fresh L1 backfill enqueue.
|
||||
for i := 0; i < burst; i++ {
|
||||
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
|
||||
}
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Issue the burst as fast as possible; the backfill worker MUST be the
|
||||
// only goroutine doing L1 writes. Allow brief slack for the test runtime
|
||||
// scheduling but anything north of +20 means goroutine leakage.
|
||||
peak := baseline
|
||||
for i := 0; i < burst; i++ {
|
||||
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
if g := runtime.NumGoroutine(); g > peak {
|
||||
peak = g
|
||||
}
|
||||
}
|
||||
|
||||
delta := peak - baseline
|
||||
if delta > 20 {
|
||||
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
|
||||
delta, baseline, peak)
|
||||
}
|
||||
|
||||
// L1 must eventually catch up via the worker. Worker drains serially so
|
||||
// give it a generous window proportional to the burst size.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
var populated int
|
||||
for i := 0; i < burst; i++ {
|
||||
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
|
||||
populated++
|
||||
}
|
||||
}
|
||||
// Be lenient: drops are acceptable under buffer pressure, just want
|
||||
// most of the keys to make it.
|
||||
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
|
||||
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
|
||||
}
|
||||
|
||||
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
|
||||
// buffer is saturated. Drops must be counted, never block, never spawn a
|
||||
// goroutine.
|
||||
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
// Tiny buffer + slow primary writes via failSet so the worker stays
|
||||
// blocked enough to overflow the buffer.
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
// Stop the worker from draining: cancel the underlying context so the
|
||||
// worker bails out, leaving us with a cold buffer and the queue method
|
||||
// itself responsible for drop accounting.
|
||||
hybrid.cancel()
|
||||
// Wait for worker to exit so it can't drain.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
|
||||
}
|
||||
|
||||
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
|
||||
"expected some drops when buffer is saturated and worker is stopped")
|
||||
}
|
||||
Vendored
+224
-197
@@ -2,20 +2,27 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Default configuration values
|
||||
const (
|
||||
defaultShardCount = 256
|
||||
defaultMaxSize = int64(10000)
|
||||
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
|
||||
defaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
value interface{}
|
||||
element *list.Element
|
||||
element interface{} // *list.Element, using interface{} to avoid import cycle
|
||||
key string
|
||||
accessCount int64
|
||||
size int64
|
||||
@@ -29,56 +36,89 @@ func (item *memoryCacheItem) isExpired() bool {
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
|
||||
// The sharded design reduces lock contention by partitioning keys across multiple shards,
|
||||
// each with its own lock.
|
||||
type MemoryCacheBackend struct {
|
||||
shards []*cacheShard
|
||||
startTime time.Time
|
||||
lastErrorTime time.Time
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
cleanupDone chan bool
|
||||
cleanupDone chan struct{}
|
||||
cleanupTicker *time.Ticker
|
||||
evictionPolicy string
|
||||
lastError string
|
||||
currentMemory int64
|
||||
misses atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
sets atomic.Int64
|
||||
hits atomic.Int64
|
||||
shardCount uint32
|
||||
shardMask uint32
|
||||
maxSize int64
|
||||
currentSize int64
|
||||
maxMemory int64
|
||||
cleanupInterval time.Duration
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Global stats (aggregated from shards)
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// State
|
||||
closed atomic.Bool
|
||||
mu sync.RWMutex // For global operations like stats and error tracking
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
// NewMemoryCacheBackend creates a new sharded memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
maxSize = defaultMaxSize
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
maxMemory = defaultMaxMemory
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
cleanupInterval = defaultCleanupInterval
|
||||
}
|
||||
|
||||
shardCount := uint32(defaultShardCount)
|
||||
|
||||
// For very small caches, reduce shard count to maintain sensible per-shard limits
|
||||
// Ensure each shard can hold at least 2 items for proper LRU behavior
|
||||
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
|
||||
shardCount /= 2
|
||||
}
|
||||
if shardCount < 1 {
|
||||
shardCount = 1
|
||||
}
|
||||
|
||||
// Per-shard limits are soft hints; global limits are enforced
|
||||
// Give shards 2x the average to allow for uneven distribution
|
||||
shardMaxSize := (maxSize * 2) / int64(shardCount)
|
||||
if shardMaxSize < 4 {
|
||||
shardMaxSize = 4
|
||||
}
|
||||
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
|
||||
if shardMaxMemory < 4096 {
|
||||
shardMaxMemory = 4096 // Minimum 4KB per shard
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
shards: make([]*cacheShard, shardCount),
|
||||
shardCount: shardCount,
|
||||
shardMask: shardCount - 1, // For fast modulo with power-of-2
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize shards
|
||||
for i := uint32(0); i < shardCount; i++ {
|
||||
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
@@ -88,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.
|
||||
return m
|
||||
}
|
||||
|
||||
// getShard returns the shard for a given key
|
||||
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
|
||||
hash := fnv32(key)
|
||||
return m.shards[hash&m.shardMask]
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
@@ -100,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
// cleanupExpired removes all expired items from all shards
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
totalRemoved := 0
|
||||
for _, shard := range m.shards {
|
||||
totalRemoved += shard.cleanup()
|
||||
}
|
||||
|
||||
if totalRemoved > 0 {
|
||||
m.evictions.Add(int64(totalRemoved))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{},
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
shard := m.getShard(key)
|
||||
value, exists, expired := shard.get(key)
|
||||
|
||||
if expired {
|
||||
// Clean up expired item
|
||||
shard.delete(key)
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
@@ -174,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
// Calculate item size
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Enforce global limits before adding new item
|
||||
m.enforceGlobalLimits(itemSize)
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
expiresAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
shard := m.getShard(key)
|
||||
shard.set(key, value, expiresAt, itemSize)
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceGlobalLimits ensures global size and memory limits are respected
|
||||
// by evicting from shards when necessary
|
||||
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
|
||||
// Check and enforce size limit
|
||||
for {
|
||||
totalSize, totalMemory := m.getGlobalStats()
|
||||
|
||||
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
|
||||
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
|
||||
|
||||
if !needsSizeEviction && !needsMemoryEviction {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the shard with the most items and evict from it
|
||||
evicted := m.evictFromLargestShard()
|
||||
if !evicted {
|
||||
break // No more items to evict
|
||||
}
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// getGlobalStats returns the total size and memory usage across all shards
|
||||
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// evictFromLargestShard evicts the globally oldest item across all shards
|
||||
// This provides true LRU behavior even with sharding
|
||||
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
|
||||
var oldestShard *cacheShard
|
||||
var oldestTime time.Time
|
||||
|
||||
for _, shard := range m.shards {
|
||||
accessTime := shard.getOldestAccessTime()
|
||||
// Skip empty shards
|
||||
if accessTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
// Find the shard with the oldest (earliest) access time
|
||||
if oldestShard == nil || accessTime.Before(oldestTime) {
|
||||
oldestTime = accessTime
|
||||
oldestShard = shard
|
||||
}
|
||||
}
|
||||
|
||||
if oldestShard == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return oldestShard.evictOne()
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
shard := m.getShard(key)
|
||||
if shard.delete(key) {
|
||||
m.deletes.Add(1)
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
shard := m.getShard(key)
|
||||
return shard.exists(key), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
@@ -289,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -306,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
var allKeys []string
|
||||
for _, shard := range m.shards {
|
||||
keys := shard.keys(pattern)
|
||||
allKeys = append(allKeys, keys...)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
return allKeys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
// Size returns the total number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
var total int64
|
||||
for _, shard := range m.shards {
|
||||
size, _ := shard.stats()
|
||||
total += size
|
||||
}
|
||||
|
||||
return m.currentSize, nil
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
@@ -337,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
ttl, exists := shard.ttl(key)
|
||||
if !exists {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
@@ -363,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
if !shard.expire(key, ttl) {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -386,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
// Aggregate stats from all shards
|
||||
var totalSize, totalMemory int64
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
@@ -409,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
CurrentSize: totalSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
MemoryUsage: totalMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
@@ -438,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error {
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
// Clear all shards
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -474,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
// GetShardCount returns the number of shards (for testing/monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardCount() uint32 {
|
||||
return m.shardCount
|
||||
}
|
||||
|
||||
// GetShardStats returns per-shard statistics (for monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
|
||||
stats := make([]map[string]int64, m.shardCount)
|
||||
for i, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
stats[i] = map[string]int64{
|
||||
"size": size,
|
||||
"memory": memory,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
@@ -502,7 +526,10 @@ func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
// Simplified pattern matching
|
||||
if len(pattern) > 0 && pattern[0] == '*' {
|
||||
suffix := pattern[1:]
|
||||
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
|
||||
}
|
||||
return key == pattern
|
||||
}
|
||||
|
||||
+294
@@ -0,0 +1,294 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheShard represents a single shard of the sharded cache
|
||||
// Each shard has its own lock for reduced contention
|
||||
type cacheShard struct {
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
mu sync.RWMutex
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
size int64
|
||||
memoryUsed int64
|
||||
}
|
||||
|
||||
// newCacheShard creates a new cache shard
|
||||
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
|
||||
return &cacheShard{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a value from this shard
|
||||
// Returns: value, exists, expired
|
||||
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
return nil, true, true // exists but expired
|
||||
}
|
||||
|
||||
// Update access time and LRU position under write lock
|
||||
s.mu.Lock()
|
||||
// Re-check item exists (could have been deleted)
|
||||
item, exists = s.items[key]
|
||||
if exists && !item.isExpired() {
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.MoveToFront(elem)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
return item.value, true, false
|
||||
}
|
||||
|
||||
// set stores a value in this shard
|
||||
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if s.maxSize > 0 && s.size >= s.maxSize {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
|
||||
// Remove old item if exists
|
||||
if oldItem, exists := s.items[key]; exists {
|
||||
s.memoryUsed -= oldItem.size
|
||||
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
s.size--
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: size,
|
||||
}
|
||||
|
||||
item.element = s.lruList.PushFront(item)
|
||||
s.items[key] = item
|
||||
s.size++
|
||||
s.memoryUsed += size
|
||||
}
|
||||
|
||||
// delete removes a key from this shard
|
||||
// Returns true if the key was deleted
|
||||
func (s *cacheShard) delete(key string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
|
||||
// exists checks if a key exists (and is not expired)
|
||||
func (s *cacheShard) exists(key string) bool {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return !item.isExpired()
|
||||
}
|
||||
|
||||
// ttl returns the remaining TTL for a key
|
||||
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, true // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return remaining, true
|
||||
}
|
||||
|
||||
// expire updates the TTL for an existing key
|
||||
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return false
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// keys returns all non-expired keys matching the pattern
|
||||
func (s *cacheShard) keys(pattern string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range s.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// clear removes all items from this shard
|
||||
func (s *cacheShard) clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.items = make(map[string]*memoryCacheItem)
|
||||
s.lruList.Init()
|
||||
s.size = 0
|
||||
s.memoryUsed = 0
|
||||
}
|
||||
|
||||
// cleanup removes expired items
|
||||
// Returns the number of items removed
|
||||
func (s *cacheShard) cleanup() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var toRemove []*memoryCacheItem
|
||||
for _, item := range s.items {
|
||||
if item.isExpired() {
|
||||
toRemove = append(toRemove, item)
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range toRemove {
|
||||
s.deleteItemLocked(item)
|
||||
}
|
||||
|
||||
return len(toRemove)
|
||||
}
|
||||
|
||||
// stats returns statistics for this shard
|
||||
func (s *cacheShard) stats() (size, memory int64) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.size, s.memoryUsed
|
||||
}
|
||||
|
||||
// deleteItemLocked removes an item (must be called with lock held)
|
||||
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
delete(s.items, item.key)
|
||||
s.size--
|
||||
s.memoryUsed -= item.size
|
||||
}
|
||||
|
||||
// evictLRULocked evicts the least recently used item (must be called with lock held)
|
||||
func (s *cacheShard) evictLRULocked() bool {
|
||||
if s.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item, ok := element.Value.(*memoryCacheItem)
|
||||
if ok {
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// evictOne evicts one item from this shard (for global limit enforcement)
|
||||
func (s *cacheShard) evictOne() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.evictLRULocked()
|
||||
}
|
||||
|
||||
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
|
||||
// Returns zero time if shard is empty
|
||||
func (s *cacheShard) getOldestAccessTime() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.lruList.Len() == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item, ok := element.Value.(*memoryCacheItem)
|
||||
if ok {
|
||||
return item.accessedAt
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a hash of a string
|
||||
// This is a fast, well-distributed hash function
|
||||
func fnv32(key string) uint32 {
|
||||
const (
|
||||
offset32 = uint32(2166136261)
|
||||
prime32 = uint32(16777619)
|
||||
)
|
||||
|
||||
hash := offset32
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash ^= uint32(key[i])
|
||||
hash *= prime32
|
||||
}
|
||||
return hash
|
||||
}
|
||||
+283
@@ -0,0 +1,283 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
|
||||
func TestShardedCache_ShardDistribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a cache with large enough size to have multiple shards
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add many items to see distribution
|
||||
numItems := 1000
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("dist-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("dist-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Check that items are distributed across multiple shards
|
||||
shardStats := backend.MemoryCacheBackend.GetShardStats()
|
||||
nonEmptyShards := 0
|
||||
for _, stat := range shardStats {
|
||||
if stat["size"] > 0 {
|
||||
nonEmptyShards++
|
||||
}
|
||||
}
|
||||
|
||||
// With good hash distribution, we should have items in multiple shards
|
||||
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
|
||||
}
|
||||
|
||||
// TestShardedCache_ShardCount tests that shard count adapts to cache size
|
||||
func TestShardedCache_ShardCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
maxSize int
|
||||
expectLowShards bool
|
||||
}{
|
||||
{5, true}, // Very small cache should have fewer shards
|
||||
{100, true}, // Small cache should have fewer shards
|
||||
{10000, false}, // Large cache should have default shards
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = tt.maxSize
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
shardCount := backend.MemoryCacheBackend.GetShardCount()
|
||||
|
||||
if tt.expectLowShards {
|
||||
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
|
||||
} else {
|
||||
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
|
||||
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "concurrent-same-key"
|
||||
initialValue := []byte("initial-value")
|
||||
|
||||
err = backend.Set(ctx, key, initialValue, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 50
|
||||
iterations := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of reads and writes
|
||||
if j%3 == 0 {
|
||||
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, newValue, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Key should still exist
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
|
||||
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a small cache to force eviction
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
// Small delay to ensure different access times
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Access some items to make them recently used
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Add more items to trigger eviction
|
||||
for i := 10; i < 15; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Recently accessed items (5-9) should still exist
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item %d should exist", i)
|
||||
}
|
||||
|
||||
// Check eviction stats
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
|
||||
func TestShardedCache_StatsAggregation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items to multiple shards
|
||||
numItems := 100
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("stats-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Read some items
|
||||
for i := 0; i < numItems/2; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Read non-existent items
|
||||
for i := 0; i < 10; i++ {
|
||||
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
|
||||
// Verify stats
|
||||
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
|
||||
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
|
||||
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
|
||||
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
|
||||
|
||||
// Verify hit rate
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
|
||||
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_Parallel benchmarks parallel access
|
||||
func BenchmarkShardedCache_Parallel(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10000; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("bench-key-%d", i%10000)
|
||||
backend.Get(ctx, key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
|
||||
func BenchmarkShardedCache_MixedOps(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("mixed-key-%d", i%1000)
|
||||
if i%3 == 0 {
|
||||
value := []byte(fmt.Sprintf("mixed-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
} else {
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
+20
-30
@@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
// Get TTL using the TTL method
|
||||
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
|
||||
if ttlErr != nil {
|
||||
// If we can't get TTL, still return the value with 0 TTL
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
@@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
// If it's not already []byte, return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
@@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"shard_count": m.MemoryCacheBackend.GetShardCount(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+112
-13
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
TLSServerName: config.TLSServerName,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
EnableTLS: config.EnableTLS,
|
||||
TLSSkipVerify: config.TLSSkipVerify,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
@@ -345,7 +348,7 @@ func (r *RedisBackend) prefixKey(key string) string {
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
|
||||
// It checks context cancellation at multiple points to ensure fast abort when the
|
||||
// caller's context is cancelled (e.g., due to request timeout).
|
||||
// caller's context is canceled (e.g., due to request timeout).
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
|
||||
@@ -377,7 +380,7 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// Check context after operation - if cancelled, don't bother retrying
|
||||
// Check context after operation - if canceled, don't bother retrying
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -431,39 +434,135 @@ func isRetryableError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
// SetMany stores multiple values in Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For single items, use regular Set
|
||||
if len(items) == 1 {
|
||||
for key, value := range items {
|
||||
return r.Set(ctx, key, value, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all SET commands
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
|
||||
for key, value := range items {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
if ttl > 0 {
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs
|
||||
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs
|
||||
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
pipeline.Queue("SET", prefixedKey, string(value))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline SetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Check responses for errors (each should be "OK")
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
if str, ok := resp.(string); ok && str == "OK" {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
// GetMany retrieves multiple values from Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
// For single key, use regular Get
|
||||
if len(keys) == 1 {
|
||||
result := make(map[string][]byte)
|
||||
value, _, exists, err := r.Get(ctx, keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
result[keys[0]] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all GET commands
|
||||
prefixedKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
prefixedKeys[i] = r.prefixKey(key)
|
||||
pipeline.Queue("GET", prefixedKeys[i])
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Process responses
|
||||
result := make(map[string][]byte)
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
// Key doesn't exist
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
// Invalid response, skip this key
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
result[keys[i]] = []byte(value)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
+461
@@ -0,0 +1,461 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTestRedis creates a miniredis instance for testing
|
||||
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "test:",
|
||||
PoolSize: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
backend.Close()
|
||||
})
|
||||
|
||||
return mr, backend
|
||||
}
|
||||
|
||||
// TestPipeline_Basic tests basic pipeline functionality
|
||||
func TestPipeline_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.Addr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SingleCommand", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "single-key", "single-value")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
})
|
||||
|
||||
t.Run("MultipleCommands", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "key1", "value1")
|
||||
pipeline.Queue("SET", "key2", "value2")
|
||||
pipeline.Queue("SET", "key3", "value3")
|
||||
pipeline.Queue("GET", "key1")
|
||||
pipeline.Queue("GET", "key2")
|
||||
pipeline.Queue("GET", "key3")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 6)
|
||||
|
||||
// First 3 are SET responses
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
assert.Equal(t, "OK", responses[1])
|
||||
assert.Equal(t, "OK", responses[2])
|
||||
|
||||
// Last 3 are GET responses
|
||||
assert.Equal(t, "value1", responses[3])
|
||||
assert.Equal(t, "value2", responses[4])
|
||||
assert.Equal(t, "value3", responses[5])
|
||||
})
|
||||
|
||||
t.Run("EmptyPipeline", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, responses)
|
||||
})
|
||||
|
||||
t.Run("NilResponses", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("GET", "nonexistent-key")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Nil(t, responses[0])
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_SetMany tests pipelined SetMany
|
||||
func TestPipeline_SetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetManyItems", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key %s should exist", key)
|
||||
assert.Equal(t, expectedValue, value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetManyEmpty", func(t *testing.T) {
|
||||
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetManySingleItem", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"single-setmany": []byte("single-value"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
value, _, exists, err := backend.Get(ctx, "single-setmany")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("single-value"), value)
|
||||
})
|
||||
|
||||
t.Run("SetManyNoTTL", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"nottl-key1": []byte("value1"),
|
||||
"nottl-key2": []byte("value2"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Keys should exist
|
||||
for key := range items {
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_GetMany tests pipelined GetMany
|
||||
func TestPipeline_GetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("getmany-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("GetManyExisting", func(t *testing.T) {
|
||||
keys := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
keys[i] = fmt.Sprintf("getmany-key-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 10)
|
||||
|
||||
for i, key := range keys {
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyMixed", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"getmany-key-0", // exists
|
||||
"nonexistent-key-1", // doesn't exist
|
||||
"getmany-key-2", // exists
|
||||
"nonexistent-key-2", // doesn't exist
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
|
||||
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
|
||||
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
|
||||
assert.NotContains(t, results, "nonexistent-key-1")
|
||||
assert.NotContains(t, results, "nonexistent-key-2")
|
||||
})
|
||||
|
||||
t.Run("GetManyEmpty", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
|
||||
t.Run("GetManySingleKey", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
|
||||
})
|
||||
|
||||
t.Run("GetManyAllNonexistent", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"nonexistent-1",
|
||||
"nonexistent-2",
|
||||
"nonexistent-3",
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_LargeBatch tests pipelining with large batches
|
||||
func TestPipeline_LargeBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany100Items", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify random samples
|
||||
for _, i := range []int{0, 25, 50, 75, 99} {
|
||||
key := fmt.Sprintf("large-batch-%d", i)
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany100Items", func(t *testing.T) {
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("large-batch-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 100)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
|
||||
func TestPipeline_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some items
|
||||
items := map[string][]byte{
|
||||
"stats-key-1": []byte("value1"),
|
||||
"stats-key-2": []byte("value2"),
|
||||
}
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get items (some exist, some don't)
|
||||
keys := []string{
|
||||
"stats-key-1",
|
||||
"stats-key-2",
|
||||
"stats-key-nonexistent",
|
||||
}
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
// Check stats
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
|
||||
assert.Equal(t, int64(2), hits, "Should have 2 hits")
|
||||
assert.Equal(t, int64(1), misses, "Should have 1 miss")
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
|
||||
func BenchmarkPipeline_SetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
|
||||
func BenchmarkPipeline_GetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
// Prepare keys
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("bench-key-%d", i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
|
||||
func BenchmarkPipeline_VsSequential(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
keys := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("compare-key-%d", i)
|
||||
keys[i] = key
|
||||
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
|
||||
}
|
||||
|
||||
b.Run("Pipelined-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for key, value := range items {
|
||||
_ = backend.Set(ctx, key, value, time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Pre-populate for get benchmarks
|
||||
_ = backend.SetMany(ctx, items, time.Hour)
|
||||
|
||||
b.Run("Pipelined-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, key := range keys {
|
||||
_, _, _, _ = backend.Get(ctx, key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
+142
-3
@@ -2,6 +2,7 @@ package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -31,6 +32,7 @@ type ConnectionPool struct {
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
@@ -39,6 +41,8 @@ type PoolConfig struct {
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
||||
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
conn, err = p.createConnection(ctx)
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if p.config.EnableTLS {
|
||||
serverName := p.config.TLSServerName
|
||||
if serverName == "" {
|
||||
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
||||
serverName = host
|
||||
}
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
} else {
|
||||
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
@@ -336,3 +358,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Pipeline represents a Redis pipeline for batch operations
|
||||
// It queues multiple commands and executes them in a single round-trip
|
||||
type Pipeline struct {
|
||||
conn *RedisConn
|
||||
commands []pipelineCommand
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// pipelineCommand represents a single command in the pipeline
|
||||
type pipelineCommand struct {
|
||||
command string
|
||||
args []string
|
||||
}
|
||||
|
||||
// NewPipeline creates a new pipeline for the connection
|
||||
func (c *RedisConn) NewPipeline() *Pipeline {
|
||||
return &Pipeline{
|
||||
conn: c,
|
||||
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
|
||||
}
|
||||
}
|
||||
|
||||
// Queue adds a command to the pipeline
|
||||
func (p *Pipeline) Queue(command string, args ...string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.commands = append(p.commands, pipelineCommand{
|
||||
command: command,
|
||||
args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute sends all queued commands and returns all responses
|
||||
// Returns a slice of responses in the same order as commands were queued
|
||||
func (p *Pipeline) Execute() ([]interface{}, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.conn.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.conn.mu.Lock()
|
||||
defer p.conn.mu.Unlock()
|
||||
|
||||
// Set write timeout for all commands
|
||||
if p.conn.writeTimeout > 0 {
|
||||
// Use longer timeout for batch operations
|
||||
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second // Cap at 30 seconds
|
||||
}
|
||||
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Write all commands (pipelining - send all before reading any responses)
|
||||
writer := NewRESPWriter(p.conn.conn)
|
||||
for _, cmd := range p.commands {
|
||||
cmdArgs := append([]string{cmd.command}, cmd.args...)
|
||||
if err := writer.WriteCommand(cmdArgs...); err != nil {
|
||||
writer.Release()
|
||||
p.conn.closed.Store(true)
|
||||
return nil, fmt.Errorf("pipeline write error: %w", err)
|
||||
}
|
||||
}
|
||||
writer.Release()
|
||||
|
||||
// Set read timeout for all responses
|
||||
if p.conn.readTimeout > 0 {
|
||||
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Read all responses
|
||||
responses := make([]interface{}, len(p.commands))
|
||||
reader := NewRESPReader(p.conn.conn)
|
||||
defer reader.Release()
|
||||
|
||||
for i := range p.commands {
|
||||
resp, err := reader.ReadResponse()
|
||||
if err != nil {
|
||||
// For nil responses, store nil instead of erroring
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
responses[i] = nil
|
||||
continue
|
||||
}
|
||||
p.conn.closed.Store(true)
|
||||
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
|
||||
}
|
||||
responses[i] = resp
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
|
||||
// Clear resets the pipeline for reuse
|
||||
func (p *Pipeline) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.commands = p.commands[:0]
|
||||
}
|
||||
|
||||
// Len returns the number of queued commands
|
||||
func (p *Pipeline) Len() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.commands)
|
||||
}
|
||||
|
||||
+31
-1
@@ -3,6 +3,7 @@ package backends
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -201,7 +202,7 @@ func TestConnectionPool_ContextCancellation(t *testing.T) {
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get another with cancelled context
|
||||
// Try to get another with canceled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
@@ -617,4 +618,33 @@ func TestRedisConn_TooManyArguments(t *testing.T) {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// TestRedisConn_RejectOversizedArgumentBytes is a regression test for CodeQL
|
||||
// alert #10 (go/allocation-size-overflow). A single argument larger than
|
||||
// maxTotalArgBytes (64 MiB) must be rejected by the per-argument overflow
|
||||
// guard in Do() before any allocation is attempted.
|
||||
func TestRedisConn_RejectOversizedArgumentBytes(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
largeArg := strings.Repeat("x", (64<<20)+1)
|
||||
|
||||
_, err = conn.Do("SET", "k", largeArg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "arguments too large")
|
||||
}
|
||||
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// drainRESPRequest consumes a single RESP request (array or inline) from r and
|
||||
// returns true on success. Any read error returns false.
|
||||
func drainRESPRequest(r *bufio.Reader) bool {
|
||||
header, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(header, "*") {
|
||||
return true // inline command (single line) — already consumed
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
|
||||
if err != nil || n <= 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// Each bulk: "$len\r\n<bytes>\r\n"
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
|
||||
// answer PING with +PONG. Returns the listener address and a self-signed cert.
|
||||
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
c, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer conn.Close()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
if !drainRESPRequest(reader) {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("+PONG\r\n"))
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
close(stopCh)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return listener.Addr().String(), der, stop
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
|
||||
// Regression test for issue #133 (enableTLS not propagated to client).
|
||||
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=false rejects a self-signed server cert.
|
||||
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "tls")
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
|
||||
// fails to handshake against a plain (non-TLS) listener.
|
||||
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
|
||||
plain, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer plain.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, acceptErr := plain.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: plain.Addr().String(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
|
||||
// when EnableTLS=false (default).
|
||||
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
Vendored
+15
-34
@@ -7,52 +7,34 @@ import (
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RESP (REdis Serialization Protocol) implementation
|
||||
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
|
||||
//
|
||||
// NOTE: sync.Pool was intentionally removed for Yaegi compatibility.
|
||||
// Yaegi (Traefik's Go interpreter) has issues with sync.Pool and reflection
|
||||
// that cause "reflect: call of reflect.Value.Field on zero Value" panics.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/120
|
||||
|
||||
var (
|
||||
ErrInvalidRESP = errors.New("invalid RESP response")
|
||||
ErrNilResponse = errors.New("nil response")
|
||||
)
|
||||
|
||||
// Object pools for memory optimization - reduces allocations by 50-70%
|
||||
var (
|
||||
readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(nil, 4096),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
writerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPWriter{
|
||||
w: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// RESPWriter writes RESP protocol messages
|
||||
type RESPWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
|
||||
// NewRESPWriter creates a new RESP writer
|
||||
func NewRESPWriter(w io.Writer) *RESPWriter {
|
||||
writer := writerPool.Get().(*RESPWriter)
|
||||
writer.w = w
|
||||
return writer
|
||||
return &RESPWriter{w: w}
|
||||
}
|
||||
|
||||
// Release returns the writer to the pool for reuse
|
||||
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
|
||||
func (w *RESPWriter) Release() {
|
||||
w.w = nil
|
||||
writerPool.Put(w)
|
||||
// No-op: pooling removed for Yaegi compatibility
|
||||
}
|
||||
|
||||
// WriteCommand writes a Redis command in RESP array format
|
||||
@@ -78,17 +60,16 @@ type RESPReader struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
|
||||
// NewRESPReader creates a new RESP reader
|
||||
func NewRESPReader(r io.Reader) *RESPReader {
|
||||
reader := readerPool.Get().(*RESPReader)
|
||||
reader.r.Reset(r)
|
||||
return reader
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(r, 4096),
|
||||
}
|
||||
}
|
||||
|
||||
// Release returns the reader to the pool for reuse
|
||||
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
|
||||
func (r *RESPReader) Release() {
|
||||
r.r.Reset(nil)
|
||||
readerPool.Put(r)
|
||||
// No-op: pooling removed for Yaegi compatibility
|
||||
}
|
||||
|
||||
// ReadResponse reads a RESP response and returns the parsed value
|
||||
|
||||
+183
@@ -0,0 +1,183 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleflightCache wraps a CacheBackend with singleflight deduplication
|
||||
// to prevent thundering herd problems when multiple concurrent requests
|
||||
// try to fetch the same uncached key.
|
||||
type SingleflightCache struct {
|
||||
backend CacheBackend
|
||||
mu sync.Mutex
|
||||
calls map[string]*singleflightCall
|
||||
|
||||
// Metrics
|
||||
deduplicatedCalls atomic.Int64
|
||||
totalCalls atomic.Int64
|
||||
}
|
||||
|
||||
// singleflightCall represents an in-flight or completed fetch call
|
||||
type singleflightCall struct {
|
||||
wg sync.WaitGroup
|
||||
val []byte
|
||||
ttl time.Duration
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
// NewSingleflightCache creates a new singleflight-wrapped cache backend
|
||||
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
|
||||
return &SingleflightCache{
|
||||
backend: backend,
|
||||
calls: make(map[string]*singleflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
// Fetcher is a function type that fetches data when cache misses
|
||||
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
|
||||
|
||||
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
|
||||
// per key when there's a cache miss. Concurrent calls for the same key will
|
||||
// wait for the first call to complete and share its result.
|
||||
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
|
||||
s.totalCalls.Add(1)
|
||||
|
||||
// Try cache first
|
||||
value, _, exists, err := s.backend.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Cache miss - use singleflight
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if there's already an in-flight call for this key
|
||||
if call, ok := s.calls[key]; ok {
|
||||
s.mu.Unlock()
|
||||
s.deduplicatedCalls.Add(1)
|
||||
|
||||
// Wait for the in-flight call to complete
|
||||
call.wg.Wait()
|
||||
|
||||
// Check context cancellation
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Create new call
|
||||
call := &singleflightCall{}
|
||||
call.wg.Add(1)
|
||||
s.calls[key] = call
|
||||
s.mu.Unlock()
|
||||
|
||||
// Execute the fetcher
|
||||
call.val, call.ttl, call.err = fetcher(ctx)
|
||||
call.done = true
|
||||
|
||||
// If successful, store in cache
|
||||
if call.err == nil && call.val != nil {
|
||||
// Use a background context for cache storage to ensure it completes
|
||||
// even if the original context is canceled
|
||||
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Signal waiting goroutines
|
||||
call.wg.Done()
|
||||
|
||||
// Clean up the call from the map after a short delay
|
||||
// This allows late arrivals to still benefit from the result
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
s.mu.Lock()
|
||||
if c, ok := s.calls[key]; ok && c == call {
|
||||
delete(s.calls, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the underlying cache backend
|
||||
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
return s.backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Set stores a value in the underlying cache backend
|
||||
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.backend.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a key from the underlying cache backend
|
||||
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the underlying cache backend
|
||||
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the underlying cache backend
|
||||
func (s *SingleflightCache) Clear(ctx context.Context) error {
|
||||
return s.backend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics including singleflight metrics
|
||||
func (s *SingleflightCache) GetStats() map[string]interface{} {
|
||||
stats := s.backend.GetStats()
|
||||
|
||||
// Add singleflight-specific stats
|
||||
totalCalls := s.totalCalls.Load()
|
||||
deduped := s.deduplicatedCalls.Load()
|
||||
|
||||
stats["singleflight_total_calls"] = totalCalls
|
||||
stats["singleflight_deduplicated"] = deduped
|
||||
if totalCalls > 0 {
|
||||
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
|
||||
} else {
|
||||
stats["singleflight_dedup_rate"] = float64(0)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
stats["singleflight_inflight"] = len(s.calls)
|
||||
s.mu.Unlock()
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend
|
||||
func (s *SingleflightCache) Close() error {
|
||||
return s.backend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (s *SingleflightCache) Ping(ctx context.Context) error {
|
||||
return s.backend.Ping(ctx)
|
||||
}
|
||||
|
||||
// GetBackend returns the underlying cache backend
|
||||
func (s *SingleflightCache) GetBackend() CacheBackend {
|
||||
return s.backend
|
||||
}
|
||||
|
||||
// ResetStats resets the singleflight statistics
|
||||
func (s *SingleflightCache) ResetStats() {
|
||||
s.totalCalls.Store(0)
|
||||
s.deduplicatedCalls.Store(0)
|
||||
}
|
||||
|
||||
// Ensure SingleflightCache implements CacheBackend
|
||||
var _ CacheBackend = (*SingleflightCache)(nil)
|
||||
+510
@@ -0,0 +1,510 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
|
||||
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CacheHit", func(t *testing.T) {
|
||||
key := "existing-key"
|
||||
value := []byte("existing-value")
|
||||
|
||||
// Pre-populate cache
|
||||
err := cache.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return []byte("fetched-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, result)
|
||||
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
|
||||
})
|
||||
|
||||
t.Run("CacheMiss", func(t *testing.T) {
|
||||
key := "missing-key"
|
||||
expectedValue := []byte("fetched-value")
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, result)
|
||||
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
|
||||
|
||||
// Verify value was stored in cache
|
||||
cached, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, cached)
|
||||
})
|
||||
|
||||
t.Run("FetcherError", func(t *testing.T) {
|
||||
key := "error-key"
|
||||
expectedErr := errors.New("fetch failed")
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Verify nothing was stored in cache
|
||||
_, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
|
||||
func TestSingleflightCache_Deduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "dedup-key"
|
||||
expectedValue := []byte("dedup-value")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([][]byte, concurrency)
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same result
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.NoError(t, errs[i])
|
||||
assert.Equal(t, expectedValue, results[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
|
||||
|
||||
// Verify deduplication stats
|
||||
stats := cache.GetStats()
|
||||
deduped := stats["singleflight_deduplicated"].(int64)
|
||||
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
|
||||
func TestSingleflightCache_DifferentKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetchStarted := make(chan struct{}, 3)
|
||||
fetchComplete := make(chan struct{})
|
||||
|
||||
fetcher := func(key string) Fetcher {
|
||||
return func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
fetchStarted <- struct{}{}
|
||||
<-fetchComplete // Wait for signal
|
||||
return []byte("value-" + key), time.Minute, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Launch concurrent requests for different keys
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("key-%d", idx)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all fetches to start
|
||||
for i := 0; i < 3; i++ {
|
||||
<-fetchStarted
|
||||
}
|
||||
|
||||
// All 3 fetches should be running in parallel
|
||||
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
|
||||
|
||||
// Release all fetches
|
||||
close(fetchComplete)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ContextCancellation tests context cancellation
|
||||
func TestSingleflightCache_ContextCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
key := "cancel-key"
|
||||
fetchStarted := make(chan struct{})
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
close(fetchStarted)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Start first request with long timeout
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := context.Background()
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}()
|
||||
|
||||
// Wait for fetch to start
|
||||
<-fetchStarted
|
||||
|
||||
// Start second request with short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
|
||||
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "error-prop-key"
|
||||
expectedErr := errors.New("intentional error")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 5
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same error
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.Error(t, errs[i])
|
||||
assert.Equal(t, expectedErr, errs[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load())
|
||||
}
|
||||
|
||||
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
|
||||
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Set", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, _, exists, err := cache.Get(ctx, "set-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("set-value"), val)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, ttl, exists, err := cache.Get(ctx, "get-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("get-value"), val)
|
||||
assert.Greater(t, ttl, time.Duration(0))
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := cache.Delete(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := cache.Exists(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
exists, err := cache.Exists(ctx, "nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = cache.Exists(ctx, "exists-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := cache.Exists(ctx, "clear-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
err := cache.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Stats tests statistics tracking
|
||||
func TestSingleflightCache_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Make some calls
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := cache.GetStats()
|
||||
|
||||
// Check singleflight stats exist
|
||||
assert.Contains(t, stats, "singleflight_total_calls")
|
||||
assert.Contains(t, stats, "singleflight_deduplicated")
|
||||
assert.Contains(t, stats, "singleflight_dedup_rate")
|
||||
assert.Contains(t, stats, "singleflight_inflight")
|
||||
|
||||
// Verify values
|
||||
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
|
||||
|
||||
// Also check underlying backend stats are included
|
||||
assert.Contains(t, stats, "hits")
|
||||
assert.Contains(t, stats, "misses")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ResetStats tests stats reset
|
||||
func TestSingleflightCache_ResetStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
|
||||
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
|
||||
|
||||
stats := cache.GetStats()
|
||||
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
|
||||
|
||||
// Reset stats
|
||||
cache.ResetStats()
|
||||
|
||||
stats = cache.GetStats()
|
||||
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
|
||||
}
|
||||
|
||||
// TestSingleflightCache_GetBackend tests GetBackend method
|
||||
func TestSingleflightCache_GetBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
assert.Equal(t, backend, cache.GetBackend())
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
|
||||
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key-%d", i%100)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
|
||||
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(time.Millisecond) // Simulate slow fetch
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
|
||||
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// All goroutines hit the same key
|
||||
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
|
||||
}
|
||||
})
|
||||
}
|
||||
Vendored
+1
-1
@@ -232,7 +232,7 @@ func (m *Manager) Close() error {
|
||||
|
||||
var firstErr error
|
||||
|
||||
if err := m.tokenCache.Close(); err != nil && firstErr == nil {
|
||||
if err := m.tokenCache.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := m.metadataCache.Close(); err != nil && firstErr == nil {
|
||||
|
||||
@@ -397,7 +397,7 @@ func (wp *WorkerPool) Submit(task func()) error {
|
||||
}
|
||||
|
||||
// worker is the main worker routine
|
||||
func (wp *WorkerPool) worker(id int) {
|
||||
func (wp *WorkerPool) worker(_ int) {
|
||||
defer wp.workerWg.Done()
|
||||
|
||||
for {
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FileStore implements Store using file-based storage.
|
||||
// This is the default storage backend for backward compatibility with existing deployments.
|
||||
// For distributed environments, consider using RedisStore instead.
|
||||
type FileStore struct {
|
||||
basePath string
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileStore creates a new file-based credentials store.
|
||||
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
|
||||
func NewFileStore(basePath string, logger Logger) *FileStore {
|
||||
if basePath == "" {
|
||||
basePath = "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
if logger == nil {
|
||||
logger = NoOpLogger()
|
||||
}
|
||||
return &FileStore{
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// BasePath returns the base path used for storing credentials
|
||||
func (s *FileStore) BasePath() string {
|
||||
return s.basePath
|
||||
}
|
||||
|
||||
// GetFilePath returns the file path for storing credentials for a specific provider.
|
||||
// For multi-tenant scenarios, each provider gets a separate file based on URL hash.
|
||||
func (s *FileStore) GetFilePath(providerURL string) string {
|
||||
if providerURL == "" {
|
||||
return s.basePath
|
||||
}
|
||||
|
||||
// Hash provider URL for filename safety and uniqueness
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes for shorter filename
|
||||
|
||||
ext := filepath.Ext(s.basePath)
|
||||
base := strings.TrimSuffix(s.basePath, ext)
|
||||
if ext == "" {
|
||||
ext = ".json"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s%s", base, hashStr, ext)
|
||||
}
|
||||
|
||||
// Save stores the client registration response to a file
|
||||
func (s *FileStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
if creds == nil {
|
||||
return fmt.Errorf("credentials cannot be nil")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
// Ensure parent directory exists
|
||||
dir := filepath.Dir(filePath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create credentials directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(creds, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions (owner read/write only)
|
||||
if err := os.WriteFile(filePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Saved client credentials to %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves stored credentials from a file.
|
||||
// Returns nil, nil if no credentials file exists (not an error).
|
||||
func (s *FileStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via GetFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No credentials file exists - not an error
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
||||
}
|
||||
|
||||
var creds ClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Loaded client credentials from %s", filePath)
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// Delete removes the credentials file for a provider
|
||||
func (s *FileStore) Delete(ctx context.Context, providerURL string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to delete
|
||||
}
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Deleted client credentials from %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
func (s *FileStore) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
_, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check credentials file: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache defines the interface for cache operations needed by RedisStore.
|
||||
// This allows the main package to provide a cache implementation without
|
||||
// creating circular dependencies.
|
||||
type Cache interface {
|
||||
// Get retrieves a value from the cache
|
||||
Get(key string) (any, bool)
|
||||
// Set stores a value in the cache with a TTL
|
||||
Set(key string, value any, ttl time.Duration) error
|
||||
// Delete removes a value from the cache
|
||||
Delete(key string)
|
||||
}
|
||||
|
||||
// RedisStore implements Store using a Cache-backed storage.
|
||||
// This storage backend enables sharing DCR credentials across multiple Traefik instances
|
||||
// in distributed environments (e.g., Kubernetes with multiple ingress pods).
|
||||
type RedisStore struct {
|
||||
cache Cache
|
||||
keyPrefix string
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new cache-backed credentials store.
|
||||
// The cache should be configured with a Redis backend for distributed storage.
|
||||
// If keyPrefix is empty, defaults to "dcr:creds:"
|
||||
func NewRedisStore(cache Cache, keyPrefix string, logger Logger) *RedisStore {
|
||||
if keyPrefix == "" {
|
||||
keyPrefix = "dcr:creds:"
|
||||
}
|
||||
if logger == nil {
|
||||
logger = NoOpLogger()
|
||||
}
|
||||
return &RedisStore{
|
||||
cache: cache,
|
||||
keyPrefix: keyPrefix,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// makeKey creates a unique cache key for a provider URL.
|
||||
// Uses SHA256 hash of the provider URL for consistent key generation across nodes.
|
||||
func (s *RedisStore) makeKey(providerURL string) string {
|
||||
if providerURL == "" {
|
||||
return s.keyPrefix + "default"
|
||||
}
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
return s.keyPrefix + hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Save stores the client registration response in the cache.
|
||||
// TTL is calculated based on client_secret_expires_at if available.
|
||||
func (s *RedisStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
if creds == nil {
|
||||
return fmt.Errorf("credentials cannot be nil")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
|
||||
// Calculate TTL based on client_secret_expires_at if available
|
||||
ttl := 30 * 24 * time.Hour // Default: 30 days
|
||||
if creds.ClientSecretExpiresAt > 0 {
|
||||
expiresAt := time.Unix(creds.ClientSecretExpiresAt, 0)
|
||||
ttl = time.Until(expiresAt)
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("credentials already expired")
|
||||
}
|
||||
// Add a small buffer to ensure we don't serve expired credentials
|
||||
if ttl > time.Minute {
|
||||
ttl -= time.Minute
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize credentials to JSON for storage
|
||||
data, err := json.Marshal(creds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Store as string in cache (will be serialized by the cache backend)
|
||||
if err := s.cache.Set(key, string(data), ttl); err != nil {
|
||||
return fmt.Errorf("failed to store credentials in cache: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Saved client credentials to cache with key %s (TTL: %v)", key, ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves stored credentials from the cache.
|
||||
// Returns nil, nil if no credentials exist (not an error).
|
||||
func (s *RedisStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
|
||||
value, exists := s.cache.Get(key)
|
||||
if !exists {
|
||||
return nil, nil // No credentials stored - not an error
|
||||
}
|
||||
|
||||
// Handle different value types from cache
|
||||
var jsonData string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
jsonData = v
|
||||
case []byte:
|
||||
jsonData = string(v)
|
||||
default:
|
||||
// Try to see if it's already the struct (from local cache)
|
||||
if creds, ok := value.(*ClientRegistrationResponse); ok {
|
||||
return creds, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected credentials type in cache: %T", value)
|
||||
}
|
||||
|
||||
var creds ClientRegistrationResponse
|
||||
if err := json.Unmarshal([]byte(jsonData), &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials from cache: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Loaded client credentials from cache with key %s", key)
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// Delete removes stored credentials from the cache
|
||||
func (s *RedisStore) Delete(ctx context.Context, providerURL string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
s.cache.Delete(key)
|
||||
|
||||
s.logger.Debugf("Deleted client credentials from cache with key %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if credentials exist in the cache for a provider
|
||||
func (s *RedisStore) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
_, exists := s.cache.Get(key)
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
// Package dcrstorage provides storage backends for OIDC Dynamic Client Registration credentials.
|
||||
// It supports both file-based and Redis-based storage for persisting client credentials
|
||||
// across application restarts and distributed deployments.
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// StorageBackend represents the type of storage backend for DCR credentials
|
||||
type StorageBackend string
|
||||
|
||||
const (
|
||||
// StorageBackendFile uses file-based storage (default for backward compatibility)
|
||||
StorageBackendFile StorageBackend = "file"
|
||||
|
||||
// StorageBackendRedis uses Redis for distributed storage
|
||||
StorageBackendRedis StorageBackend = "redis"
|
||||
|
||||
// StorageBackendAuto automatically selects Redis if available, otherwise file
|
||||
StorageBackendAuto StorageBackend = "auto"
|
||||
)
|
||||
|
||||
// Logger interface for DCR storage operations
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...any)
|
||||
Info(msg string)
|
||||
Infof(format string, args ...any)
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
}
|
||||
|
||||
// Store defines the interface for storing DCR credentials.
|
||||
// This abstraction allows different storage backends (file, Redis) to be used
|
||||
// for persisting OIDC Dynamic Client Registration credentials across nodes.
|
||||
type Store interface {
|
||||
// Save stores the client registration response for a provider
|
||||
// The providerURL is used as a key to support multi-tenant scenarios
|
||||
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
|
||||
|
||||
// Load retrieves stored credentials for a provider
|
||||
// Returns nil, nil if no credentials exist (not an error)
|
||||
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
|
||||
|
||||
// Delete removes stored credentials for a provider
|
||||
Delete(ctx context.Context, providerURL string) error
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
Exists(ctx context.Context, providerURL string) (bool, error)
|
||||
}
|
||||
|
||||
// noOpLogger is a no-op implementation of Logger for default use
|
||||
type noOpLogger struct{}
|
||||
|
||||
func (n noOpLogger) Debug(msg string) {}
|
||||
func (n noOpLogger) Debugf(format string, args ...any) {}
|
||||
func (n noOpLogger) Info(msg string) {}
|
||||
func (n noOpLogger) Infof(format string, args ...any) {}
|
||||
func (n noOpLogger) Error(msg string) {}
|
||||
func (n noOpLogger) Errorf(format string, args ...any) {}
|
||||
|
||||
// NoOpLogger returns a no-op logger instance
|
||||
func NoOpLogger() Logger {
|
||||
return noOpLogger{}
|
||||
}
|
||||
@@ -0,0 +1,464 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockCache implements Cache for testing
|
||||
type mockCache struct {
|
||||
data map[string]cacheEntry
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value any
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newMockCache() *mockCache {
|
||||
return &mockCache{data: make(map[string]cacheEntry)}
|
||||
}
|
||||
|
||||
func (m *mockCache) Get(key string) (any, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
entry, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
return entry.value, true
|
||||
}
|
||||
|
||||
func (m *mockCache) Set(key string, value any, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = cacheEntry{
|
||||
value: value,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCache) Delete(key string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
}
|
||||
|
||||
func TestFileStore_SaveLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "test-access-token",
|
||||
RegistrationClientURI: "https://example.com/register/test-client-id",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
ResponseTypes: []string{"code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
|
||||
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
tempDir2 := t.TempDir()
|
||||
store2 := NewFileStore(filepath.Join(tempDir2, "nonexistent.json"), nil)
|
||||
|
||||
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent file: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
|
||||
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Expected credentials to not exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existent credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_MultiProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
provider1 := "https://auth1.example.com"
|
||||
provider2 := "https://auth2.example.com"
|
||||
|
||||
creds1 := &ClientRegistrationResponse{
|
||||
ClientID: "client-1",
|
||||
ClientSecret: "secret-1",
|
||||
}
|
||||
creds2 := &ClientRegistrationResponse{
|
||||
ClientID: "client-2",
|
||||
ClientSecret: "secret-2",
|
||||
}
|
||||
|
||||
if err := store.Save(ctx, provider1, creds1); err != nil {
|
||||
t.Fatalf("Failed to save creds1: %v", err)
|
||||
}
|
||||
if err := store.Save(ctx, provider2, creds2); err != nil {
|
||||
t.Fatalf("Failed to save creds2: %v", err)
|
||||
}
|
||||
|
||||
loaded1, err := store.Load(ctx, provider1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds1: %v", err)
|
||||
}
|
||||
if loaded1.ClientID != "client-1" {
|
||||
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
|
||||
}
|
||||
|
||||
loaded2, err := store.Load(ctx, provider2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds2: %v", err)
|
||||
}
|
||||
if loaded2.ClientID != "client-2" {
|
||||
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
|
||||
}
|
||||
|
||||
if err := store.Delete(ctx, provider1); err != nil {
|
||||
t.Fatalf("Failed to delete creds1: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, provider2)
|
||||
if !exists {
|
||||
t.Error("Provider 2 credentials should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
|
||||
for range concurrency {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = store.Save(ctx, providerURL, creds)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for range concurrency {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = store.Load(ctx, providerURL)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after concurrent access: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test-client" {
|
||||
t.Error("Credentials corrupted after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty provider URL uses default path", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
err := store.Save(ctx, "", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with empty provider URL failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with empty provider URL failed: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials with empty provider URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_DefaultPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewFileStore("", nil)
|
||||
|
||||
if store.BasePath() == "" {
|
||||
t.Error("Expected default base path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisStore_WithMockCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "redis-test-client",
|
||||
ClientSecret: "redis-test-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "redis-test-token",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
}
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisStore_TTLFromExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("expired credentials should fail", func(t *testing.T) {
|
||||
expiredCreds := &ClientRegistrationResponse{
|
||||
ClientID: "expired-client",
|
||||
ClientSecret: "expired-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "no-expiry-client",
|
||||
ClientSecret: "no-expiry-secret",
|
||||
ClientSecretExpiresAt: 0,
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://noexpiry.example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials without expiry: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_CorruptedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
filePath := store.GetFilePath(providerURL)
|
||||
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
|
||||
t.Fatalf("Failed to write corrupted file: %v", err)
|
||||
}
|
||||
|
||||
_, err := store.Load(ctx, providerURL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for corrupted JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_DirectoryCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
|
||||
store := NewFileStore(deepPath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
|
||||
err := store.Save(ctx, "https://example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save with nested directory: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after nested directory creation: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials from nested directory")
|
||||
}
|
||||
}
|
||||
@@ -173,7 +173,7 @@ func (m *FeatureManager) LoadFromEnv() {
|
||||
for name, flag := range flags {
|
||||
envVar := "FEATURE_" + name
|
||||
if value := os.Getenv(envVar); value != "" {
|
||||
enabled := strings.ToLower(value) == "true" || value == "1"
|
||||
enabled := strings.EqualFold(value, "true") || value == "1"
|
||||
flag.enabled.Store(enabled)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
|
||||
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if strings.ToLower(scope) != ScopeOfflineAccess {
|
||||
if !strings.EqualFold(scope, ScopeOfflineAccess) {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +147,8 @@ func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
|
||||
return p
|
||||
}
|
||||
case ProviderTypeKeycloak:
|
||||
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/auth/realms/") {
|
||||
// Match both Keycloak <17 (`/auth/realms/`) and 17+ (`/realms/`).
|
||||
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/realms/") {
|
||||
return p
|
||||
}
|
||||
case ProviderTypeAWSCognito:
|
||||
|
||||
@@ -225,10 +225,15 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
|
||||
expected: oktaProvider,
|
||||
},
|
||||
{
|
||||
name: "Keycloak provider detection",
|
||||
name: "Keycloak provider detection (legacy /auth/realms/)",
|
||||
issuerURL: "https://auth.example.com/auth/realms/master",
|
||||
expected: keycloakProvider,
|
||||
},
|
||||
{
|
||||
name: "Keycloak provider detection (modern /realms/, KC 17+)",
|
||||
issuerURL: "https://auth.example.com/realms/master",
|
||||
expected: keycloakProvider,
|
||||
},
|
||||
{
|
||||
name: "AWS Cognito provider detection",
|
||||
issuerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example",
|
||||
|
||||
@@ -18,16 +18,17 @@ func GetProviderWarnings(providerType ProviderType) []ProviderWarning {
|
||||
|
||||
switch providerType {
|
||||
case ProviderTypeGitHub:
|
||||
warnings = append(warnings, ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "warning",
|
||||
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
|
||||
})
|
||||
warnings = append(warnings, ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "info",
|
||||
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
|
||||
})
|
||||
warnings = append(warnings,
|
||||
ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "warning",
|
||||
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
|
||||
},
|
||||
ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "info",
|
||||
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
|
||||
})
|
||||
|
||||
case ProviderTypeAuth0:
|
||||
warnings = append(warnings, ProviderWarning{
|
||||
|
||||
@@ -116,7 +116,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
|
||||
// Continue to next attempt
|
||||
case <-ctx.Done():
|
||||
re.RecordFailure()
|
||||
return fmt.Errorf("retry cancelled: %w", ctx.Err())
|
||||
return fmt.Errorf("retry canceled: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,7 +301,7 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
allMetrics["summary"] = map[string]interface{}{
|
||||
summary := map[string]interface{}{
|
||||
"totalMechanisms": len(rm.mechanisms),
|
||||
"totalRequests": totalRequests,
|
||||
"totalSuccesses": totalSuccesses,
|
||||
@@ -310,8 +310,9 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
|
||||
|
||||
if totalRequests > 0 {
|
||||
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
|
||||
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
|
||||
summary["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
|
||||
}
|
||||
allMetrics["summary"] = summary
|
||||
|
||||
return allMetrics
|
||||
}
|
||||
|
||||
@@ -223,7 +223,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
if execErr == nil {
|
||||
t.Error("Expected error when context is cancelled")
|
||||
t.Error("Expected error when context is canceled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when context is already cancelled")
|
||||
t.Error("Expected error when context is already canceled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
|
||||
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
|
||||
{name: "EOF", err: errors.New("EOF"), expected: false},
|
||||
{name: "random error", err: errors.New("something else"), expected: false},
|
||||
{name: "context cancelled", err: context.Canceled, expected: false},
|
||||
{name: "context canceled", err: context.Canceled, expected: false},
|
||||
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
|
||||
// the fix for issue #132: token refresh path hardcoded the "email" claim and
|
||||
// ignored the configured userIdentifierClaim. Keycloak users without an email
|
||||
// claim (using sub or another identifier) were being kicked out on refresh
|
||||
// even though their initial login worked.
|
||||
//
|
||||
// The callback path (auth_flow.go) already honored userIdentifierClaim with
|
||||
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
|
||||
// after PR #100 (commit a316a98).
|
||||
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
expectedIdentifier string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "sub claim configured, only sub present (Keycloak no-email case)",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-keycloak-12345",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user-uuid-keycloak-12345",
|
||||
},
|
||||
{
|
||||
name: "preferred_username configured, claim present",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"preferred_username": "alice",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "alice",
|
||||
},
|
||||
{
|
||||
name: "configured claim missing, falls back to sub",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "fallback-sub-id",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
},
|
||||
{
|
||||
name: "email default, email present (backward compatibility)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "email default, no email and no sub - refresh fails",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: false,
|
||||
expectedIdentifier: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager(
|
||||
"test-encryption-key-32-bytes-long!!",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
NewLogger("error"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("session manager: %v", err)
|
||||
}
|
||||
defer sessionManager.Shutdown()
|
||||
|
||||
capturedClaims := tt.claims
|
||||
tOidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
userIdentifierClaim: tt.userIdentifierClaim,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &EnhancedMockTokenExchanger{
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token-jwt",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
},
|
||||
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
|
||||
extractClaimsFunc: func(token string) (map[string]any, error) {
|
||||
return capturedClaims, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
session.SetRefreshToken("initial-refresh-token")
|
||||
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
if refreshed != tt.expectSuccess {
|
||||
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
|
||||
}
|
||||
|
||||
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
|
||||
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
|
||||
// verifies the fix for issue #134.
|
||||
//
|
||||
// Symptom (before fix): with a Redis backend wired into UniversalCache,
|
||||
// caching the parsed *parsedJWKS triggered:
|
||||
//
|
||||
// json: cannot unmarshal number 2251513...
|
||||
// into Go value of type float64
|
||||
//
|
||||
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
|
||||
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
|
||||
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
|
||||
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
|
||||
// interface{} parses numbers as float64, which cannot represent that range.
|
||||
// The user saw the error log on every request even though auth still worked
|
||||
// (fallback path rebuilt the keys in memory).
|
||||
//
|
||||
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
|
||||
// distributed backend never sees them.
|
||||
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
|
||||
var fetchCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&fetchCount, 1)
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
infoBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(infoBuf, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
ctx := context.Background()
|
||||
|
||||
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "first GetPublicKey should succeed")
|
||||
require.NotNil(t, pub1)
|
||||
gotRSA, ok := pub1.(*rsa.PublicKey)
|
||||
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
|
||||
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
|
||||
|
||||
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "second GetPublicKey should succeed")
|
||||
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
|
||||
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
|
||||
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
|
||||
|
||||
errOutput := errBuf.String()
|
||||
assert.NotContains(t, errOutput, "Failed to deserialize",
|
||||
"deserialize error must not appear with the fix in place; got: %s", errOutput)
|
||||
assert.NotContains(t, errOutput, "into Go value of type float64",
|
||||
"float64 unmarshal error must not appear; got: %s", errOutput)
|
||||
|
||||
parsedKey := server.URL + parsedKeysSuffix
|
||||
jwksKey := server.URL
|
||||
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
|
||||
fullKey := redisCfg.RedisPrefix + k
|
||||
assert.False(t, mr.Exists(fullKey),
|
||||
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
|
||||
// data left in Redis under a JWK :parsed key from a prior buggy version is
|
||||
// ignored: the local-only fix never reads that key, so no log spam, and the
|
||||
// fallback path returns a real *rsa.PublicKey.
|
||||
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134stale:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pre-poison Redis with the kind of payload the old buggy path would have
|
||||
// produced (huge unquoted JSON number for the modulus). With the fix the
|
||||
// JWKCache must not even read this key.
|
||||
poisoned := []byte("\x01" + strings.Replace(
|
||||
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
|
||||
"NUMBER", rsaKey.N.String(), 1,
|
||||
))
|
||||
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
|
||||
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pub)
|
||||
gotRSA, ok := pub.(*rsa.PublicKey)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
|
||||
|
||||
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
|
||||
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
|
||||
}
|
||||
|
||||
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
|
||||
// pair never reads or writes the configured backend.
|
||||
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "local:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 10,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
type unsafeShape struct {
|
||||
hidden map[string]interface{}
|
||||
}
|
||||
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
|
||||
|
||||
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
|
||||
|
||||
got, found := cache.GetLocal("local-key")
|
||||
require.True(t, found)
|
||||
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
|
||||
|
||||
for _, k := range mr.Keys() {
|
||||
assert.NotContains(t, k, "local-key",
|
||||
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
|
||||
}
|
||||
|
||||
cache.mu.Lock()
|
||||
delete(cache.items, "local-key")
|
||||
cache.lruList.Init()
|
||||
cache.currentSize = 0
|
||||
cache.currentMemory = 0
|
||||
cache.mu.Unlock()
|
||||
|
||||
_, found = cache.GetLocal("local-key")
|
||||
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
|
||||
}
|
||||
|
||||
// big2bytes returns the big-endian byte slice for a positive int.
|
||||
func big2bytes(e int) []byte {
|
||||
if e <= 0 {
|
||||
return []byte{}
|
||||
}
|
||||
var buf []byte
|
||||
for e > 0 {
|
||||
buf = append([]byte{byte(e & 0xff)}, buf...)
|
||||
e >>= 8
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// samePublicKey reports whether two crypto.PublicKey instances represent the
|
||||
// same RSA key, used to confirm cache hits return identical reconstructed
|
||||
// keys.
|
||||
func samePublicKey(a, b interface{}) bool {
|
||||
ar, ok1 := a.(*rsa.PublicKey)
|
||||
br, ok2 := b.(*rsa.PublicKey)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
@@ -18,6 +19,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// parsedKeysSuffix marks the parallel UniversalCache entry that stores
|
||||
// pre-parsed public keys for a given JWKS URL.
|
||||
const parsedKeysSuffix = ":parsed"
|
||||
|
||||
// parsedJWKS holds keys decoded from a JWKSet, indexed by kid. Storing the
|
||||
// already-parsed crypto.PublicKey avoids re-running the DER/PEM round trip
|
||||
// on every JWT verification — a costly operation under the yaegi interpreter
|
||||
// that hosts Traefik plugins.
|
||||
type parsedJWKS struct {
|
||||
keys map[string]crypto.PublicKey
|
||||
}
|
||||
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It can represent different key types including RSA, EC, and symmetric keys.
|
||||
type JWK struct {
|
||||
@@ -49,6 +62,7 @@ type JWKCache struct {
|
||||
// JWKCacheInterface defines the contract for JWK caching implementations.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error)
|
||||
Cleanup()
|
||||
Close()
|
||||
}
|
||||
@@ -62,9 +76,15 @@ func NewJWKCache() *JWKCache {
|
||||
}
|
||||
|
||||
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
|
||||
//
|
||||
// The entry is stored locally only via SetLocal/GetLocal. Going through a
|
||||
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
|
||||
// into map[string]interface{}, the type assertion below fails, and every
|
||||
// request refetches from the upstream. JWK rotation is rare and a per-replica
|
||||
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Check cache first
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -74,7 +94,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -91,11 +111,75 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the parsed public key for a given kid, fetching and
|
||||
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
|
||||
// stored alongside the raw JWKSet under a sibling cache key with the same
|
||||
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
|
||||
//
|
||||
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
|
||||
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
|
||||
// which contain *big.Int that marshals to a hundreds-digit JSON number.
|
||||
// On a distributed backend round-trip, json.Unmarshal into interface{} would
|
||||
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
|
||||
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
|
||||
// Marshal, leaking the modulus into the cached payload (issue #134).
|
||||
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
parsedKey := jwksURL + parsedKeysSuffix
|
||||
if v, found := c.cache.GetLocal(parsedKey); found {
|
||||
if pj, ok := v.(*parsedJWKS); ok {
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jwks, err := c.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pj := buildParsedJWKS(jwks)
|
||||
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
// buildParsedJWKS pre-parses every JWK in the set into the matching
|
||||
// crypto.PublicKey, indexed by kid. Errors on individual keys are skipped so
|
||||
// a single bad key does not block the rest of the keyset.
|
||||
func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
|
||||
out := make(map[string]crypto.PublicKey, len(jwks.Keys))
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid == "" {
|
||||
continue
|
||||
}
|
||||
var pub crypto.PublicKey
|
||||
var err error
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
pub, err = k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
pub, err = k.ToECDSAPublicKey()
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out[k.Kid] = pub
|
||||
}
|
||||
return &parsedJWKS{keys: out}
|
||||
}
|
||||
|
||||
// Cleanup is a no-op as cleanup is handled by UniversalCache
|
||||
func (c *JWKCache) Cleanup() {
|
||||
// Handled internally by UniversalCache
|
||||
@@ -213,9 +297,9 @@ func (jwk *JWK) ToECDSAPublicKey() (*ecdsa.PublicKey, error) {
|
||||
// GetKey finds a key by its ID (kid) in the JWKSet.
|
||||
// Returns nil if no key with the given ID is found.
|
||||
func (jwks *JWKSet) GetKey(kid string) *JWK {
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
return &key
|
||||
for i := range jwks.Keys {
|
||||
if jwks.Keys[i].Kid == kid {
|
||||
return &jwks.Keys[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -120,7 +120,7 @@ func getReplayCacheStats() (size int, maxSize int) {
|
||||
// Parameters:
|
||||
// - ctx: Parent context for cancellation
|
||||
// - logger: Logger for debug output (can be nil)
|
||||
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
|
||||
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Define the cleanup task function
|
||||
@@ -528,6 +528,21 @@ func verifyNotBefore(notBefore float64) error {
|
||||
// - An error if the key parsing fails, the algorithm is unsupported,
|
||||
// or the signature verification fails
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
return verifySignatureWithKey(tokenString, pubKey, alg)
|
||||
}
|
||||
|
||||
// verifySignatureWithKey verifies a JWT signature using an already-parsed
|
||||
// public key, skipping the PEM-encode/decode round trip that verifySignature
|
||||
// performs. This is the hot path used by VerifyJWTSignatureAndClaims.
|
||||
func verifySignatureWithKey(tokenString string, pubKey crypto.PublicKey, alg string) error {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
@@ -537,14 +552,6 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
var hashFunc crypto.Hash
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
|
||||
@@ -0,0 +1,502 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
|
||||
// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// logoutTokenType is the expected typ claim for logout tokens
|
||||
// #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
|
||||
logoutTokenType = "logout+jwt"
|
||||
|
||||
// sessionInvalidationTTL is how long to remember invalidated sessions
|
||||
// Should be at least as long as your session max age
|
||||
sessionInvalidationTTL = 25 * time.Hour
|
||||
)
|
||||
|
||||
// LogoutTokenClaims represents the claims in an OIDC logout token
|
||||
// as defined in OpenID Connect Back-Channel Logout 1.0
|
||||
type LogoutTokenClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audience interface{} `json:"aud"` // Can be string or []string
|
||||
IssuedAt int64 `json:"iat"`
|
||||
JTI string `json:"jti"`
|
||||
Events map[string]interface{} `json:"events"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"` // Must NOT be present
|
||||
}
|
||||
|
||||
// handleBackchannelLogout processes OIDC Backchannel Logout requests.
|
||||
// It accepts POST requests with a logout_token parameter containing a JWT
|
||||
// that identifies which session(s) to terminate.
|
||||
//
|
||||
// According to OpenID Connect Back-Channel Logout 1.0:
|
||||
// - The logout_token is a JWT signed by the IdP
|
||||
// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
|
||||
// - The RP must validate the token and invalidate the matching session(s)
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer
|
||||
// - req: The HTTP request containing the logout_token
|
||||
func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Processing backchannel logout request")
|
||||
|
||||
// Backchannel logout must be POST
|
||||
if req.Method != http.MethodPost {
|
||||
t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
|
||||
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse form data to get logout_token
|
||||
if err := req.ParseForm(); err != nil {
|
||||
t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
|
||||
http.Error(rw, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logoutToken := req.FormValue("logout_token")
|
||||
if logoutToken == "" {
|
||||
// Also try reading from request body as raw JWT
|
||||
body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
|
||||
if err == nil && len(body) > 0 {
|
||||
logoutToken = string(body)
|
||||
}
|
||||
}
|
||||
|
||||
if logoutToken == "" {
|
||||
t.logger.Error("Backchannel logout: missing logout_token")
|
||||
http.Error(rw, "logout_token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and validate the logout token
|
||||
claims, err := t.validateLogoutToken(logoutToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
|
||||
// Return 400 for invalid token per spec
|
||||
http.Error(rw, "Invalid logout token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate session(s) based on sid or sub
|
||||
if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
|
||||
t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
|
||||
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
|
||||
claims.SessionID, claims.Subject)
|
||||
|
||||
// Return 200 OK with empty body per spec
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
|
||||
// It accepts GET requests with 'iss' and 'sid' query parameters that identify
|
||||
// which session to terminate. The IdP typically loads this URL in an iframe.
|
||||
//
|
||||
// According to OpenID Connect Front-Channel Logout 1.0:
|
||||
// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
|
||||
// - The RP should clear the session and return a response (typically empty or image)
|
||||
// - The response must be cacheable to allow the IdP to load it in an iframe
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer
|
||||
// - req: The HTTP request containing iss and sid parameters
|
||||
func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Processing front-channel logout request")
|
||||
|
||||
// Front-channel logout should be GET
|
||||
if req.Method != http.MethodGet {
|
||||
t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
|
||||
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Get iss and sid from query parameters
|
||||
iss := req.URL.Query().Get("iss")
|
||||
sid := req.URL.Query().Get("sid")
|
||||
|
||||
// Validate issuer matches our expected issuer
|
||||
t.metadataMu.RLock()
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if iss != "" && iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
|
||||
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Must have at least sid for front-channel logout
|
||||
if sid == "" {
|
||||
t.logger.Error("Front-channel logout: missing sid parameter")
|
||||
http.Error(rw, "sid parameter required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate the session
|
||||
if err := t.invalidateSession(sid, ""); err != nil {
|
||||
t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
|
||||
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
|
||||
|
||||
// Return a minimal HTML response that's suitable for iframe loading
|
||||
// Set headers to allow embedding and caching
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.Header().Set("Cache-Control", "no-cache, no-store")
|
||||
rw.Header().Set("Pragma", "no-cache")
|
||||
// Allow embedding in iframes from any origin (required for front-channel logout)
|
||||
rw.Header().Del("X-Frame-Options")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write([]byte("<!DOCTYPE html><html><head><title>Logged Out</title></head><body></body></html>"))
|
||||
}
|
||||
|
||||
// validateLogoutToken parses and validates a logout token JWT.
|
||||
// It verifies the token signature, issuer, audience, and required claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT logout token
|
||||
//
|
||||
// Returns:
|
||||
// - The parsed logout token claims
|
||||
// - An error if validation fails
|
||||
func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(tokenString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse logout token: %w", err)
|
||||
}
|
||||
|
||||
// Check token type if present
|
||||
if typ, ok := jwt.Header["typ"].(string); ok {
|
||||
// The typ should be "logout+jwt" or omitted
|
||||
if typ != "" && typ != logoutTokenType && typ != "JWT" {
|
||||
return nil, fmt.Errorf("invalid token type: %s", typ)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify signature only (not standard claims - logout tokens don't have 'exp')
|
||||
if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
|
||||
return nil, fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
claims := &LogoutTokenClaims{}
|
||||
claimsJSON, err := json.Marshal(jwt.Claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal claims: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(claimsJSON, claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||
}
|
||||
|
||||
// Validate required claims
|
||||
t.metadataMu.RLock()
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
// Validate issuer
|
||||
if claims.Issuer != expectedIssuer {
|
||||
return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
|
||||
}
|
||||
|
||||
// Validate audience (must contain our client_id)
|
||||
if !t.validateLogoutTokenAudience(claims.Audience) {
|
||||
return nil, fmt.Errorf("audience validation failed")
|
||||
}
|
||||
|
||||
// Validate iat (issued at) - must be present and not too old
|
||||
if claims.IssuedAt == 0 {
|
||||
return nil, fmt.Errorf("missing iat claim")
|
||||
}
|
||||
iatTime := time.Unix(claims.IssuedAt, 0)
|
||||
// Allow up to 5 minutes clock skew and 10 minutes token age
|
||||
if time.Since(iatTime) > 15*time.Minute {
|
||||
return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
|
||||
}
|
||||
// Token should not be from the future (with 5 min clock skew tolerance)
|
||||
if iatTime.After(time.Now().Add(5 * time.Minute)) {
|
||||
return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
|
||||
}
|
||||
|
||||
// Validate events claim - must contain the logout event
|
||||
if claims.Events == nil {
|
||||
return nil, fmt.Errorf("missing events claim")
|
||||
}
|
||||
if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
|
||||
return nil, fmt.Errorf("missing backchannel-logout event in events claim")
|
||||
}
|
||||
|
||||
// Validate that nonce is NOT present (per spec)
|
||||
if claims.Nonce != "" {
|
||||
return nil, fmt.Errorf("nonce claim must not be present in logout token")
|
||||
}
|
||||
|
||||
// Must have either sid or sub (or both)
|
||||
if claims.SessionID == "" && claims.Subject == "" {
|
||||
return nil, fmt.Errorf("logout token must contain either sid or sub claim")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// validateLogoutTokenAudience checks if the logout token audience contains our client_id
|
||||
func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
|
||||
switch v := aud.(type) {
|
||||
case string:
|
||||
return v == t.clientID
|
||||
case []interface{}:
|
||||
for _, a := range v {
|
||||
if s, ok := a.(string); ok && s == t.clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, a := range v {
|
||||
if a == t.clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// verifyLogoutTokenSignature verifies only the signature of a logout token.
|
||||
// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
|
||||
// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwt: The parsed JWT structure
|
||||
// - tokenString: The raw token string for signature verification
|
||||
//
|
||||
// Returns:
|
||||
// - An error if signature verification fails
|
||||
func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
|
||||
t.logger.Debug("Verifying logout token signature")
|
||||
|
||||
// Read jwksURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
jwksURL := t.jwksURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
if jwks == nil {
|
||||
return fmt.Errorf("JWKS is nil, cannot verify token")
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok || kid == "" {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
|
||||
alg, ok := jwt.Header["alg"].(string)
|
||||
if !ok || alg == "" {
|
||||
return fmt.Errorf("missing algorithm in token header")
|
||||
}
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
for i := range jwks.Keys {
|
||||
if jwks.Keys[i].Kid == kid {
|
||||
matchingKey = &jwks.Keys[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchingKey == nil {
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Debug("Logout token signature verified successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// invalidateSession marks a session as invalidated in the session invalidation cache.
|
||||
// It stores entries by both sid and sub if available.
|
||||
//
|
||||
// Parameters:
|
||||
// - sid: The session ID to invalidate (from the 'sid' claim)
|
||||
// - sub: The subject to invalidate (from the 'sub' claim)
|
||||
//
|
||||
// Returns:
|
||||
// - An error if the invalidation fails
|
||||
func (t *TraefikOidc) invalidateSession(sid, sub string) error {
|
||||
if t.sessionInvalidationCache == nil {
|
||||
return fmt.Errorf("session invalidation cache not initialized")
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
|
||||
// Store by session ID
|
||||
if sid != "" {
|
||||
key := t.buildSessionInvalidationKey("sid", sid)
|
||||
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
|
||||
t.logger.Debugf("Invalidated session by sid: %s", sid)
|
||||
}
|
||||
|
||||
// Store by subject (invalidates all sessions for this user)
|
||||
if sub != "" {
|
||||
key := t.buildSessionInvalidationKey("sub", sub)
|
||||
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
|
||||
t.logger.Debugf("Invalidated session by sub: %s", sub)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isSessionInvalidated checks if a session has been invalidated via backchannel
|
||||
// or front-channel logout.
|
||||
//
|
||||
// Parameters:
|
||||
// - sid: The session ID to check
|
||||
// - sub: The subject to check
|
||||
// - sessionCreatedAt: When the session was created (to compare against invalidation time)
|
||||
//
|
||||
// Returns:
|
||||
// - true if the session has been invalidated, false otherwise
|
||||
func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
|
||||
if t.sessionInvalidationCache == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Truncate session creation time to seconds for fair comparison with Unix timestamps
|
||||
sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
|
||||
|
||||
// Check by session ID first (more specific)
|
||||
if sid != "" {
|
||||
key := t.buildSessionInvalidationKey("sid", sid)
|
||||
if val, found := t.sessionInvalidationCache.Get(key); found {
|
||||
if invalidatedAt, ok := val.(int64); ok {
|
||||
// Session was invalidated at or after it was created
|
||||
invalidationTime := time.Unix(invalidatedAt, 0)
|
||||
if !invalidationTime.Before(sessionCreatedAtSec) {
|
||||
t.logger.Debugf("Session invalidated by sid: %s", sid)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check by subject (all sessions for this user)
|
||||
if sub != "" {
|
||||
key := t.buildSessionInvalidationKey("sub", sub)
|
||||
if val, found := t.sessionInvalidationCache.Get(key); found {
|
||||
if invalidatedAt, ok := val.(int64); ok {
|
||||
// Sessions for this subject created at or before invalidation are invalid
|
||||
invalidationTime := time.Unix(invalidatedAt, 0)
|
||||
if !invalidationTime.Before(sessionCreatedAtSec) {
|
||||
t.logger.Debugf("Session invalidated by sub: %s", sub)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// buildSessionInvalidationKey creates a cache key for session invalidation
|
||||
func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
|
||||
return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
|
||||
}
|
||||
|
||||
// extractSessionInfo extracts sid and sub from an ID token for session tracking
|
||||
func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
|
||||
if idToken == "" {
|
||||
return "", "", time.Time{}
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(idToken)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}
|
||||
}
|
||||
|
||||
// Extract sid (session ID)
|
||||
if sidVal, ok := jwt.Claims["sid"].(string); ok {
|
||||
sid = sidVal
|
||||
}
|
||||
|
||||
// Extract sub (subject)
|
||||
if subVal, ok := jwt.Claims["sub"].(string); ok {
|
||||
sub = subVal
|
||||
}
|
||||
|
||||
// Extract iat for session creation time
|
||||
if iatVal, ok := jwt.Claims["iat"].(float64); ok {
|
||||
createdAt = time.Unix(int64(iatVal), 0)
|
||||
} else {
|
||||
// Default to now if iat not present
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
return sid, sub, createdAt
|
||||
}
|
||||
|
||||
// determineLogoutPath checks if the given path matches any logout URL
|
||||
func (t *TraefikOidc) determineLogoutPath(path string) string {
|
||||
// Check backchannel logout path
|
||||
if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
|
||||
return "backchannel"
|
||||
}
|
||||
|
||||
// Check front-channel logout path
|
||||
if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
|
||||
return "frontchannel"
|
||||
}
|
||||
|
||||
// Check regular logout path (for RP-initiated logout)
|
||||
if path == t.logoutURLPath {
|
||||
return "rp"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
|
||||
func normalizeLogoutPath(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
// Prevent open redirect: ensure second character is not / or \
|
||||
// This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
|
||||
if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
|
||||
// Strip leading slashes/backslashes and re-normalize
|
||||
path = strings.TrimLeft(path, "/\\")
|
||||
if path != "" {
|
||||
path = "/" + path
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
+1660
File diff suppressed because it is too large
Load Diff
@@ -113,12 +113,26 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
}
|
||||
// Setup HTTP client
|
||||
caPool, err := config.loadCACertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load CA certificates: %w", err)
|
||||
}
|
||||
if config.InsecureSkipVerify {
|
||||
logger.Errorf("SECURITY WARNING: InsecureSkipVerify is enabled for the OIDC provider. TLS certificate verification is DISABLED. Do not use in production.")
|
||||
}
|
||||
var httpClient *http.Client
|
||||
if config.HTTPClient != nil {
|
||||
httpClient = config.HTTPClient
|
||||
} else {
|
||||
httpClient = CreateDefaultHTTPClient()
|
||||
defaultCfg := DefaultHTTPClientConfig()
|
||||
defaultCfg.RootCAs = caPool
|
||||
defaultCfg.InsecureSkipVerify = config.InsecureSkipVerify
|
||||
httpClient = CreatePooledHTTPClient(defaultCfg)
|
||||
}
|
||||
tokenCfg := TokenHTTPClientConfig()
|
||||
tokenCfg.RootCAs = caPool
|
||||
tokenCfg.InsecureSkipVerify = config.InsecureSkipVerify
|
||||
tokenHTTPClient := CreatePooledHTTPClient(tokenCfg)
|
||||
goroutineWG := &sync.WaitGroup{}
|
||||
cacheManager := GetGlobalCacheManagerWithConfig(goroutineWG, config)
|
||||
|
||||
@@ -199,7 +213,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: cacheManager.GetSharedTokenCache(),
|
||||
httpClient: httpClient,
|
||||
tokenHTTPClient: CreateTokenHTTPClient(),
|
||||
tokenHTTPClient: tokenHTTPClient,
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
|
||||
@@ -212,16 +226,30 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
maxRefreshTokenAge: func() time.Duration {
|
||||
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
|
||||
if config.MaxRefreshTokenAgeSeconds > 0 {
|
||||
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
stripAuthCookies: config.StripAuthCookies,
|
||||
enableBackchannelLogout: config.EnableBackchannelLogout,
|
||||
enableFrontchannelLogout: config.EnableFrontchannelLogout,
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
@@ -240,6 +268,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
||||
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
||||
|
||||
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
|
||||
// call, preventing the thundering herd that yields invalid_grant when the IdP
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -287,17 +320,22 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
startReplayCacheCleanup(pluginCtx, logger)
|
||||
|
||||
// Start memory monitoring for leak detection and performance insights
|
||||
// Start memory monitoring for leak detection and performance insights.
|
||||
// The interval is clamped to MinMemoryMonitorInterval (30s) inside
|
||||
// StartMonitoring; tests that need deterministic sampling should call
|
||||
// MemoryMonitor.Refresh() directly instead of waiting on a fast ticker.
|
||||
memoryMonitor := GetGlobalMemoryMonitor()
|
||||
monitorInterval := 60 * time.Second
|
||||
if isTestMode() {
|
||||
monitorInterval = 100 * time.Millisecond // Fast interval for tests
|
||||
}
|
||||
memoryMonitor.StartMonitoring(pluginCtx, monitorInterval)
|
||||
memoryMonitor.StartMonitoring(pluginCtx, DefaultMemoryMonitorInterval)
|
||||
logger.Debug("Started global memory monitoring")
|
||||
|
||||
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
|
||||
|
||||
// Log callback URL configuration to help diagnose redirect loop issues.
|
||||
// If callbackURL is a full URL instead of a path, the callback matching
|
||||
// in ServeHTTP will silently fail because req.URL.Path is compared directly.
|
||||
logger.Debugf("TraefikOidc.New: callbackURL (redirURLPath) configured as: %q", t.redirURLPath)
|
||||
logger.Debugf("TraefikOidc.New: logoutURLPath configured as: %q", t.logoutURLPath)
|
||||
|
||||
t.providerURL = config.ProviderURL
|
||||
|
||||
// Use singleton resource manager for metadata initialization
|
||||
@@ -433,6 +471,19 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
|
||||
t.dcrConfig,
|
||||
t.providerURL,
|
||||
)
|
||||
|
||||
// Set up storage backend for credentials persistence
|
||||
if t.dcrConfig.PersistCredentials {
|
||||
cacheManager := GetGlobalCacheManagerWithConfig(t.goroutineWG, nil)
|
||||
store, err := NewDCRCredentialsStore(t.dcrConfig, cacheManager, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to create DCR credentials store: %v", err)
|
||||
// Continue without persistence - registration will still work
|
||||
} else {
|
||||
t.dynamicClientRegistrar.SetStore(store)
|
||||
t.logger.Debugf("DCR credentials store initialized with backend: %s", t.dcrConfig.StorageBackend)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get registration endpoint (from metadata or config override)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config Marshalling Tests
|
||||
// Config Marshaling Tests
|
||||
|
||||
func TestConfig_MarshalJSON(t *testing.T) {
|
||||
config := &Config{
|
||||
|
||||
+535
-42
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_EventStream tests the event-stream bypass functionality
|
||||
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
|
||||
// handshake must skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but it must STILL require an authenticated session, otherwise
|
||||
// any caller could reach the backend by setting Accept: text/event-stream.
|
||||
func TestServeHTTP_EventStream(t *testing.T) {
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
|
||||
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
|
||||
}
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
// Build an authenticated session and inject its cookies onto req.
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated SSE request to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
|
||||
// handshake bypasses the OIDC redirect (clients can't follow it) but the
|
||||
// session must already be authenticated, otherwise the backend is exposed
|
||||
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
|
||||
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected event-stream request to bypass OIDC")
|
||||
}
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
// Mixed-case + multi-token Connection header to exercise parsing.
|
||||
req.Header.Set("Connection", "keep-alive, Upgrade")
|
||||
req.Header.Set("Upgrade", "WebSocket")
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("ws-user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "ws-user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
|
||||
// Sanity: requests without Upgrade headers must NOT hit the WS
|
||||
// bypass branch (otherwise the new code path could short-circuit
|
||||
// normal authentication).
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("backend must not be called for unauthenticated plain HTTP")
|
||||
}))
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code == http.StatusOK {
|
||||
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
|
||||
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "successful authorization with email",
|
||||
setupSession: func() *MockSessionData {
|
||||
session := &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "no email triggers reauth",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "",
|
||||
userIdentifier: "",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "roles and groups authorization",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "unauthorized role/group returns 403",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "template headers processing",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "OPTIONS request with CORS",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
manager: &SessionManager{logger: NewLogger("debug")},
|
||||
}
|
||||
// Copy values from mock to concrete session
|
||||
concreteSession.SetEmail(session.email)
|
||||
concreteSession.SetUserIdentifier(session.userIdentifier)
|
||||
concreteSession.SetIDToken(session.idToken)
|
||||
concreteSession.SetAccessToken(session.accessToken)
|
||||
concreteSession.SetRefreshToken(session.refreshToken)
|
||||
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
|
||||
// MockSessionData is a test implementation of SessionData interface
|
||||
type MockSessionData struct {
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
userIdentifier string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
|
||||
func (m *MockSessionData) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSessionData) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
|
||||
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
|
||||
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
|
||||
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -710,3 +862,344 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies tests the stripAuthCookies configuration option.
|
||||
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
|
||||
func TestStripAuthCookies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripAuthCookies bool
|
||||
expectOIDCCookies bool
|
||||
expectAppCookies bool
|
||||
}{
|
||||
{
|
||||
name: "stripAuthCookies=false (default) forwards all cookies",
|
||||
stripAuthCookies: false,
|
||||
expectOIDCCookies: true,
|
||||
expectAppCookies: true,
|
||||
},
|
||||
{
|
||||
name: "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
|
||||
stripAuthCookies: true,
|
||||
expectOIDCCookies: false,
|
||||
expectAppCookies: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
cookiePrefix := sessionManager.GetCookiePrefix()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: tt.stripAuthCookies,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Get a valid session first (before adding fake cookies)
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Now add OIDC session cookies (simulating what the browser would send)
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
|
||||
|
||||
// Add non-OIDC application cookies (these must always pass through)
|
||||
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
|
||||
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Check for OIDC cookies in captured cookies
|
||||
hasOIDCCookie := false
|
||||
hasAppSession := false
|
||||
hasTheme := false
|
||||
for _, c := range capturedCookies {
|
||||
if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
|
||||
hasOIDCCookie = true
|
||||
}
|
||||
if c.Name == "my_app_session" {
|
||||
hasAppSession = true
|
||||
}
|
||||
if c.Name == "theme" {
|
||||
hasTheme = true
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectOIDCCookies && !hasOIDCCookie {
|
||||
t.Error("expected OIDC cookies to be forwarded to backend")
|
||||
}
|
||||
if !tt.expectOIDCCookies && hasOIDCCookie {
|
||||
t.Error("expected OIDC cookies to be stripped before forwarding to backend")
|
||||
}
|
||||
|
||||
if tt.expectAppCookies && !hasAppSession {
|
||||
t.Error("expected my_app_session cookie to be forwarded to backend")
|
||||
}
|
||||
if tt.expectAppCookies && !hasTheme {
|
||||
t.Error("expected theme cookie to be forwarded to backend")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_NoCookies verifies stripping works when the request has no cookies.
|
||||
func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
if len(capturedCookies) != 0 {
|
||||
t.Errorf("expected no cookies, got %d", len(capturedCookies))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_OnlyOIDCCookies verifies that when all cookies are OIDC cookies,
|
||||
// the Cookie header is empty after stripping.
|
||||
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
var capturedCookieHeader string
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookieHeader = r.Header.Get("Cookie")
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
cookiePrefix := sessionManager.GetCookiePrefix()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only OIDC cookies
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
if len(capturedCookies) != 0 {
|
||||
t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
|
||||
}
|
||||
if capturedCookieHeader != "" {
|
||||
t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_OnlyAppCookies verifies that non-OIDC cookies pass through
|
||||
// untouched when stripping is enabled.
|
||||
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only non-OIDC cookies
|
||||
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "abc123"})
|
||||
req.AddCookie(&http.Cookie{Name: "lang", Value: "en"})
|
||||
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
if len(capturedCookies) != 3 {
|
||||
t.Errorf("expected 3 cookies, got %d", len(capturedCookies))
|
||||
}
|
||||
|
||||
cookieNames := make(map[string]bool)
|
||||
for _, c := range capturedCookies {
|
||||
cookieNames[c.Name] = true
|
||||
}
|
||||
for _, expected := range []string{"my_app_session", "lang", "theme"} {
|
||||
if !cookieNames[expected] {
|
||||
t.Errorf("expected cookie %q to be forwarded", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_CustomPrefix verifies stripping works with a custom cookie prefix.
|
||||
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create session manager with custom prefix
|
||||
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
customPrefix := sm.GetCookiePrefix()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add cookies with the custom prefix (should be stripped)
|
||||
req.AddCookie(&http.Cookie{Name: customPrefix + "m", Value: "session-data"})
|
||||
req.AddCookie(&http.Cookie{Name: customPrefix + "s_0", Value: "chunk0"})
|
||||
|
||||
// Add default-prefix cookie (should NOT be stripped — different prefix)
|
||||
req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
|
||||
|
||||
// Add app cookie (should NOT be stripped)
|
||||
req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
cookieNames := make(map[string]bool)
|
||||
for _, c := range capturedCookies {
|
||||
cookieNames[c.Name] = true
|
||||
}
|
||||
|
||||
// Custom prefix cookies should be stripped
|
||||
if cookieNames[customPrefix+"m"] {
|
||||
t.Errorf("expected cookie %q to be stripped", customPrefix+"m")
|
||||
}
|
||||
if cookieNames[customPrefix+"s_0"] {
|
||||
t.Errorf("expected cookie %q to be stripped", customPrefix+"s_0")
|
||||
}
|
||||
|
||||
// Default prefix cookie should pass through (different prefix)
|
||||
if !cookieNames["_oidc_raczylo_m"] {
|
||||
t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
|
||||
}
|
||||
|
||||
// App cookie should pass through
|
||||
if !cookieNames["my_app"] {
|
||||
t.Error("expected my_app cookie to pass through")
|
||||
}
|
||||
}
|
||||
|
||||
+41
-15
@@ -208,6 +208,32 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
if m.JWKS == nil {
|
||||
return nil, fmt.Errorf("JWKS is nil")
|
||||
}
|
||||
for i := range m.JWKS.Keys {
|
||||
k := &m.JWKS.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) Cleanup() {
|
||||
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
|
||||
// Real cleanup is for expired entries, not resetting all data
|
||||
@@ -554,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -577,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -634,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -652,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -680,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -715,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
@@ -746,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetIDToken(validToken) // Ensure ID token is also set
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
@@ -766,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -788,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -2153,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAccessToken(expiredToken)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
@@ -2730,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2756,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2783,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
@@ -2803,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2825,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
|
||||
+24
-15
@@ -9,13 +9,18 @@ import (
|
||||
// LazyBackgroundTask wraps BackgroundTask to provide delayed initialization.
|
||||
// This prevents memory leaks from unnecessary background tasks by starting
|
||||
// them only when actually needed, reducing resource usage in idle scenarios.
|
||||
//
|
||||
// Lifecycle is one-shot: once Stop has been called the task cannot be
|
||||
// restarted. The underlying BackgroundTask uses sync.Once for Start and
|
||||
// refuses to re-run after Stop, so restart is not supported by design.
|
||||
type LazyBackgroundTask struct {
|
||||
// BackgroundTask is the underlying task implementation
|
||||
*BackgroundTask
|
||||
// started tracks whether the task has been activated
|
||||
// mu guards the started flag against concurrent StartIfNeeded / Stop calls.
|
||||
mu sync.Mutex
|
||||
// started tracks whether the task has been activated.
|
||||
// Only mutated while holding mu.
|
||||
started bool
|
||||
// startOnce ensures single initialization
|
||||
startOnce sync.Once
|
||||
}
|
||||
|
||||
// NewLazyBackgroundTask creates a background task that doesn't start immediately.
|
||||
@@ -29,24 +34,28 @@ func NewLazyBackgroundTask(name string, interval time.Duration, taskFunc func(),
|
||||
}
|
||||
|
||||
// StartIfNeeded starts the background task only if it hasn't been started yet.
|
||||
// Uses sync.Once to ensure thread-safe single initialization.
|
||||
// Safe to call concurrently. After Stop has been called this is a no-op;
|
||||
// the task is not restartable.
|
||||
func (lt *LazyBackgroundTask) StartIfNeeded() {
|
||||
lt.startOnce.Do(func() {
|
||||
if !lt.started {
|
||||
lt.BackgroundTask.Start()
|
||||
lt.started = true
|
||||
}
|
||||
})
|
||||
lt.mu.Lock()
|
||||
defer lt.mu.Unlock()
|
||||
if lt.started {
|
||||
return
|
||||
}
|
||||
lt.BackgroundTask.Start()
|
||||
lt.started = true
|
||||
}
|
||||
|
||||
// Stop stops the background task if it was started.
|
||||
// Resets the start state to allow potential future re-initialization.
|
||||
// Once stopped, the task cannot be restarted (see type doc).
|
||||
func (lt *LazyBackgroundTask) Stop() {
|
||||
if lt.started {
|
||||
lt.BackgroundTask.Stop()
|
||||
lt.started = false
|
||||
lt.startOnce = sync.Once{}
|
||||
lt.mu.Lock()
|
||||
defer lt.mu.Unlock()
|
||||
if !lt.started {
|
||||
return
|
||||
}
|
||||
lt.BackgroundTask.Stop()
|
||||
lt.started = false
|
||||
}
|
||||
|
||||
// NewLazyCacheWithLogger creates a cache that doesn't start cleanup until first use.
|
||||
|
||||
+142
-12
@@ -58,13 +58,21 @@ func (mpl MemoryPressureLevel) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryMonitor provides comprehensive memory monitoring and alerting
|
||||
// MemoryMonitor provides comprehensive memory monitoring and alerting.
|
||||
//
|
||||
// Memory sampling is expensive: runtime.ReadMemStats is a stop-the-world
|
||||
// operation. To keep latency predictable the monitor caches the most recent
|
||||
// sample and only refreshes it when the background ticker fires, when TriggerGC
|
||||
// is invoked, or when a caller explicitly calls Refresh(). GetCurrentStats is a
|
||||
// cheap read of that cached sample.
|
||||
type MemoryMonitor struct {
|
||||
lastGCTime time.Time
|
||||
startTime time.Time
|
||||
lastStats *MemoryStats
|
||||
cachedMemStats runtime.MemStats
|
||||
logger *Logger
|
||||
alertThresholds MemoryAlertThresholds
|
||||
config MemoryMonitorConfig
|
||||
baselineGoroutines int
|
||||
baselineHeap uint64
|
||||
heapGrowthRate float64
|
||||
@@ -84,6 +92,30 @@ type MemoryAlertThresholds struct {
|
||||
GCFrequency float64 // Alert when GC frequency exceeds this per minute
|
||||
}
|
||||
|
||||
// MemoryMonitorConfig configures the memory monitor's scheduling behavior.
|
||||
// Thresholds are kept separate in MemoryAlertThresholds.
|
||||
type MemoryMonitorConfig struct {
|
||||
// Interval between background samples. Must be >= MinMemoryMonitorInterval
|
||||
// (30s). Values below the minimum are clamped when monitoring starts.
|
||||
Interval time.Duration
|
||||
}
|
||||
|
||||
// Default and minimum interval values. The minimum exists because
|
||||
// runtime.ReadMemStats is stop-the-world and hammering it on a hot loop causes
|
||||
// noticeable latency spikes, especially under yaegi.
|
||||
const (
|
||||
DefaultMemoryMonitorInterval = 60 * time.Second
|
||||
MinMemoryMonitorInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// DefaultMemoryMonitorConfig returns a config with sensible production
|
||||
// defaults.
|
||||
func DefaultMemoryMonitorConfig() MemoryMonitorConfig {
|
||||
return MemoryMonitorConfig{
|
||||
Interval: DefaultMemoryMonitorInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMemoryAlertThresholds returns sensible default alert thresholds
|
||||
func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
|
||||
return MemoryAlertThresholds{
|
||||
@@ -95,35 +127,82 @@ func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemoryMonitor creates a new memory monitor
|
||||
// NewMemoryMonitor creates a new memory monitor using default scheduling
|
||||
// configuration. See NewMemoryMonitorWithConfig for full control.
|
||||
func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryMonitor {
|
||||
return NewMemoryMonitorWithConfig(logger, thresholds, DefaultMemoryMonitorConfig())
|
||||
}
|
||||
|
||||
// NewMemoryMonitorWithConfig creates a new memory monitor with an explicit
|
||||
// scheduling config.
|
||||
//
|
||||
// NOTE: the constructor performs a single runtime.ReadMemStats call to capture
|
||||
// baseline heap / goroutine / GC counters used for leak and growth detection.
|
||||
// This is a one-time stop-the-world cost at startup; all subsequent samples
|
||||
// only happen on the monitoring ticker or on explicit Refresh() calls.
|
||||
func NewMemoryMonitorWithConfig(logger *Logger, thresholds MemoryAlertThresholds, config MemoryMonitorConfig) *MemoryMonitor {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
if config.Interval <= 0 {
|
||||
config.Interval = DefaultMemoryMonitorInterval
|
||||
}
|
||||
|
||||
// One-time initial sample to seed baselines used for growth / leak
|
||||
// detection. All subsequent sampling is gated by the monitoring ticker or
|
||||
// explicit Refresh() calls.
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
return &MemoryMonitor{
|
||||
mm := &MemoryMonitor{
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
alertThresholds: thresholds,
|
||||
config: config,
|
||||
baselineHeap: memStats.HeapAlloc,
|
||||
baselineGoroutines: runtime.NumGoroutine(),
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
lastGCCount: memStats.NumGC,
|
||||
}
|
||||
mm.cachedMemStats = memStats
|
||||
return mm
|
||||
}
|
||||
|
||||
// GetCurrentStats collects current memory statistics
|
||||
// GetCurrentStats returns the most recently sampled memory statistics.
|
||||
//
|
||||
// This is a cheap cached read: it does NOT call runtime.ReadMemStats. Samples
|
||||
// are refreshed only by the monitoring ticker or by an explicit call to
|
||||
// Refresh(). If no sample has been produced yet, stats derived from the
|
||||
// constructor-time raw sample are returned (with no additional STW cost).
|
||||
func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
mm.mu.RLock()
|
||||
stats := mm.lastStats
|
||||
mm.mu.RUnlock()
|
||||
if stats != nil {
|
||||
return stats
|
||||
}
|
||||
return mm.buildStatsFromCache()
|
||||
}
|
||||
|
||||
// Refresh synchronously samples current memory statistics via
|
||||
// runtime.ReadMemStats and updates the cached value. This is the only path
|
||||
// (other than the monitoring ticker and TriggerGC) that pays the stop-the-world
|
||||
// cost. Use it in tests or in callers that explicitly need a fresh sample.
|
||||
func (mm *MemoryMonitor) Refresh() *MemoryStats {
|
||||
return mm.sample()
|
||||
}
|
||||
|
||||
// sample performs a stop-the-world ReadMemStats, updates the cached raw stats,
|
||||
// computes a derived MemoryStats snapshot, and stores it as lastStats.
|
||||
func (mm *MemoryMonitor) sample() *MemoryStats {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Calculate GC frequency
|
||||
// Calculate GC frequency relative to the previous snapshot.
|
||||
gcFrequency := 0.0
|
||||
mm.mu.RLock()
|
||||
lastStats := mm.lastStats
|
||||
@@ -168,6 +247,7 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
mm.updateHeapGrowthTracking(stats)
|
||||
|
||||
mm.mu.Lock()
|
||||
mm.cachedMemStats = memStats
|
||||
mm.lastStats = stats
|
||||
mm.lastGCCount = memStats.NumGC
|
||||
mm.mu.Unlock()
|
||||
@@ -175,6 +255,35 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
// buildStatsFromCache constructs a MemoryStats snapshot from the cached raw
|
||||
// runtime.MemStats without issuing a new ReadMemStats call. Used as a fallback
|
||||
// when GetCurrentStats is called before the first sample() has completed.
|
||||
func (mm *MemoryMonitor) buildStatsFromCache() *MemoryStats {
|
||||
mm.mu.RLock()
|
||||
memStats := mm.cachedMemStats
|
||||
mm.mu.RUnlock()
|
||||
|
||||
stats := &MemoryStats{
|
||||
HeapAllocBytes: memStats.HeapAlloc,
|
||||
HeapSysBytes: memStats.HeapSys,
|
||||
HeapIdleBytes: memStats.HeapIdle,
|
||||
HeapInuseBytes: memStats.HeapInuse,
|
||||
HeapReleasedBytes: memStats.HeapReleased,
|
||||
HeapObjects: memStats.HeapObjects,
|
||||
StackInuseBytes: memStats.StackInuse,
|
||||
StackSysBytes: memStats.StackSys,
|
||||
GCSysBytes: memStats.GCSys,
|
||||
NumGoroutines: runtime.NumGoroutine(),
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
GCFrequency: 0.0,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
mm.collectApplicationStats(stats)
|
||||
stats.MemoryPressure = mm.calculateMemoryPressure(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// collectApplicationStats gathers application-specific memory stats
|
||||
func (mm *MemoryMonitor) collectApplicationStats(stats *MemoryStats) {
|
||||
// Get session count from ChunkManager if available
|
||||
@@ -229,7 +338,7 @@ func (mm *MemoryMonitor) updateGoroutineTracking(stats *MemoryStats) {
|
||||
}
|
||||
|
||||
// Check for potential goroutine leak
|
||||
if stats.NumGoroutines > mm.baselineGoroutines+int(mm.alertThresholds.GoroutineCount) {
|
||||
if stats.NumGoroutines > mm.baselineGoroutines+mm.alertThresholds.GoroutineCount {
|
||||
mm.mu.Lock()
|
||||
wasAlert := mm.goroutineLeakAlert
|
||||
if !wasAlert {
|
||||
@@ -302,7 +411,16 @@ var (
|
||||
globalMonitoringMutex sync.Mutex
|
||||
)
|
||||
|
||||
// StartMonitoring starts continuous memory monitoring as a global singleton
|
||||
// StartMonitoring starts continuous memory monitoring as a global singleton.
|
||||
//
|
||||
// The effective interval is resolved as follows:
|
||||
// 1. If the caller passes a positive interval, that is used.
|
||||
// 2. Otherwise the configured MemoryMonitorConfig.Interval is used.
|
||||
// 3. Otherwise the built-in default (60s) is used.
|
||||
//
|
||||
// The result is then clamped to a minimum of MinMemoryMonitorInterval (30s) to
|
||||
// avoid stop-the-world ReadMemStats storms. Callers that need rapid updates in
|
||||
// tests should call Refresh() directly instead of spinning the ticker fast.
|
||||
func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Duration) {
|
||||
globalMonitoringMutex.Lock()
|
||||
defer globalMonitoringMutex.Unlock()
|
||||
@@ -316,7 +434,17 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
|
||||
}
|
||||
|
||||
if interval <= 0 {
|
||||
interval = 30 * time.Second
|
||||
interval = mm.config.Interval
|
||||
}
|
||||
if interval <= 0 {
|
||||
interval = DefaultMemoryMonitorInterval
|
||||
}
|
||||
if interval < MinMemoryMonitorInterval {
|
||||
if !isTestMode() {
|
||||
mm.logger.Debug("Memory monitor interval %v is below minimum %v; clamping",
|
||||
interval, MinMemoryMonitorInterval)
|
||||
}
|
||||
interval = MinMemoryMonitorInterval
|
||||
}
|
||||
|
||||
registry := GetGlobalTaskRegistry()
|
||||
@@ -325,7 +453,7 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
|
||||
"memory-monitor",
|
||||
interval,
|
||||
func() {
|
||||
stats := mm.GetCurrentStats()
|
||||
stats := mm.sample()
|
||||
mm.LogMemoryStats(stats)
|
||||
mm.checkAlerts(stats)
|
||||
},
|
||||
@@ -369,14 +497,16 @@ func (mm *MemoryMonitor) checkAlerts(stats *MemoryStats) {
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerGC forces garbage collection and logs the impact
|
||||
// TriggerGC forces garbage collection and logs the impact. Both the before and
|
||||
// after measurements are fresh samples (explicit Refresh() calls) because the
|
||||
// comparison is meaningless against a stale cached snapshot.
|
||||
func (mm *MemoryMonitor) TriggerGC() {
|
||||
before := mm.GetCurrentStats()
|
||||
before := mm.Refresh()
|
||||
|
||||
runtime.GC()
|
||||
runtime.GC() // Run twice to ensure full collection
|
||||
|
||||
after := mm.GetCurrentStats()
|
||||
after := mm.Refresh()
|
||||
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
|
||||
|
||||
+273
-60
@@ -13,6 +13,99 @@ import (
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// bypassReason describes why a request is being forwarded without OIDC auth.
|
||||
// It is only used for logging and to decide whether extra side-effects
|
||||
// (propagating the user header from an existing session) should run.
|
||||
const (
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
|
||||
// (RFC 6455). The middleware can only see the handshake; once Traefik
|
||||
// completes the upgrade it forwards frames directly, so we never re-process
|
||||
// per-frame traffic. We bypass auth on the handshake the same way we do for
|
||||
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// shouldBypassAuth decides whether a request must skip OIDC authentication
|
||||
// entirely. It returns (true, reason) when either the request path matches a
|
||||
// configured excluded URL, the Accept header asks for a text/event-stream
|
||||
// response (SSE), or the request is a WebSocket upgrade handshake. The
|
||||
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
|
||||
// kind (e.g. propagating user headers).
|
||||
//
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
|
||||
// WebSocket traffic is never blocked by a slow/broken provider.
|
||||
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
return true, bypassReasonExcluded
|
||||
}
|
||||
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
|
||||
return true, bypassReasonSSE
|
||||
}
|
||||
if isWebSocketUpgrade(req) {
|
||||
return true, bypassReasonWebSocket
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
|
||||
// requests and, on success, copies the authenticated user's identity onto
|
||||
// the outgoing request so downstream services can see who the user is.
|
||||
//
|
||||
// Returns true when the request carries a valid authenticated session and
|
||||
// the bypass should proceed. Returns false when no usable session is
|
||||
// present; callers must then reject the request (typically with 401) to
|
||||
// prevent unauthenticated traffic from reaching the backend just by setting
|
||||
// `Accept: text/event-stream` or sending a WebSocket upgrade.
|
||||
//
|
||||
// The check is cookie-only: the session cookie is sealed by our encryption
|
||||
// key, so the authenticated flag cannot be forged. We do NOT run full token
|
||||
// signature verification here so that SSE/WS keeps working when the OIDC
|
||||
// provider is briefly unavailable for JWK fetches.
|
||||
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
|
||||
if t.sessionManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
|
||||
return false
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
}
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
|
||||
return true
|
||||
}
|
||||
|
||||
// ServeHTTP implements the main middleware logic for processing HTTP requests.
|
||||
// It handles the complete OIDC authentication flow including:
|
||||
// - Excluded URL bypass
|
||||
@@ -26,6 +119,31 @@ import (
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The incoming HTTP request.
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Log request entry for debugging routing issues
|
||||
t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path)
|
||||
|
||||
// Handle logout requests early - before waiting for OIDC initialization
|
||||
// This allows users to logout even if the OIDC provider is unavailable
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.logger.Debugf("Logout path matched early: %s", req.URL.Path)
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle backchannel logout (IdP-initiated POST with logout_token)
|
||||
if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath {
|
||||
t.logger.Debug("Backchannel logout path matched")
|
||||
t.handleBackchannelLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle front-channel logout (IdP-initiated GET with sid/iss in iframe)
|
||||
if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath {
|
||||
t.logger.Debug("Front-channel logout path matched")
|
||||
t.handleFrontchannelLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
@@ -42,6 +160,43 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
// URLs, SSE and WebSocket upgrade requests must not block on provider
|
||||
// init. For SSE/WebSocket we ALSO require an authenticated session
|
||||
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
|
||||
// of in-flight streams can't follow an OIDC redirect, so forwarding
|
||||
// unauthenticated traffic would silently expose the backend.
|
||||
if bypass, reason := t.shouldBypassAuth(req); bypass {
|
||||
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
|
||||
switch reason {
|
||||
case bypassReasonExcluded:
|
||||
// Operator-declared excluded URLs forward unconditionally.
|
||||
t.next.ServeHTTP(rw, req)
|
||||
case bypassReasonSSE, bypassReasonWebSocket:
|
||||
// Skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but still require an authenticated session.
|
||||
// Otherwise an unauthenticated client could hit the backend
|
||||
// just by setting Accept: text/event-stream or sending a
|
||||
// WebSocket upgrade.
|
||||
if !t.applyBypassUserHeaders(req, reason) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
default:
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Log waiting for initialization to help diagnose hanging requests
|
||||
t.logger.Debug("Waiting for OIDC provider initialization...")
|
||||
|
||||
// time.NewTimer + Stop avoids leaking a goroutine+channel for 30s on every
|
||||
// request when initComplete fires quickly (would happen with time.After).
|
||||
initTimer := time.NewTimer(30 * time.Second)
|
||||
defer initTimer.Stop()
|
||||
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
// Read issuerURL with RLock
|
||||
@@ -72,24 +227,13 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
t.sendErrorResponse(rw, req, "Request canceled", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-initTimer.C:
|
||||
t.logger.Error("Timeout waiting for OIDC initialization")
|
||||
t.sendErrorResponse(rw, req, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/event-stream") {
|
||||
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Bypass checks already ran before the init wait; no need to repeat them.
|
||||
t.sessionManager.CleanupOldCookies(rw, req)
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
@@ -107,6 +251,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.sendErrorResponse(rw, req, "Critical session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Sub-resource requests (script/image/fetch/serviceWorker) must not
|
||||
// trigger an OIDC redirect from this path either: they would overwrite
|
||||
// any in-flight CSRF/nonce in the session. Let the next HTML navigation
|
||||
// initiate the flow. See issue #129.
|
||||
if t.isAjaxRequest(req) || t.isNonNavigationRequest(req) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
scheme := utils.DetermineScheme(req, t.forceHTTPS)
|
||||
host := utils.DetermineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
@@ -120,14 +272,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
host := utils.DetermineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
// Check if the current request is the OIDC callback
|
||||
t.logger.Debugf("Checking callback URL match: request_path=%q, configured_callback=%q", req.URL.Path, t.redirURLPath)
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.logger.Debugf("Callback URL matched, processing OIDC callback (redirect_url=%s)", redirectURL)
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath)
|
||||
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
|
||||
@@ -137,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -160,8 +312,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshTokenPresent := session.GetRefreshToken() != ""
|
||||
|
||||
// Check if this is an AJAX request that should receive 401 instead of redirect
|
||||
isAjaxRequest := t.isAjaxRequest(req)
|
||||
// Decide whether to answer with 401 instead of a redirect. AJAX requests
|
||||
// cannot follow a 302 into an IdP, and sub-resource loads (script/image/
|
||||
// fetch/serviceWorker) must not trigger a fresh OIDC flow because parallel
|
||||
// loads would each overwrite the session CSRF/nonce (issue #129). Only
|
||||
// top-level HTML navigations should redirect.
|
||||
isAjaxRequest := t.isAjaxRequest(req) || t.isNonNavigationRequest(req)
|
||||
|
||||
// Check if refresh token is likely expired (older than 6 hours)
|
||||
refreshTokenExpired := refreshTokenPresent && t.isRefreshTokenExpired(session)
|
||||
@@ -205,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
userIdentifier = session.GetUserIdentifier()
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
@@ -255,40 +411,79 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
// Reset redirect count to prevent loops when session is invalid
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
tokenForClaims := session.GetIDToken()
|
||||
if tokenForClaims == "" {
|
||||
tokenForClaims = session.GetAccessToken()
|
||||
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
// Reset redirect count to prevent loops when token is missing
|
||||
// Check if session has been invalidated via backchannel or front-channel logout
|
||||
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
|
||||
idToken := session.GetIDToken()
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
}
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token, so concurrent dashboard
|
||||
// panel requests on the same session don't repeatedly base64-decode and
|
||||
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
|
||||
// that hosts Traefik plugins). idClaims is reused below by the
|
||||
// header-templates branch.
|
||||
idToken := session.GetIDToken()
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
// Choose which claims drive groups/roles extraction. Prefer the ID
|
||||
// token (cached) and fall back to the access token if there is no ID
|
||||
// token in the session — matching the prior behavior for opaque
|
||||
// ID-token providers.
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if accessToken := session.GetAccessToken(); accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
var groups, roles []string
|
||||
|
||||
if groupClaimsErr == nil && groupClaims != nil {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize empty slices
|
||||
var groups, roles []string
|
||||
|
||||
if tokenForClaims != "" {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Reset redirect count to prevent loops when claim extraction fails
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
} else if err == nil {
|
||||
if err == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
@@ -307,51 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
if idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
claims, err := t.extractClaimsFunc(session.GetIDToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
|
||||
if idClaimsErr != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
|
||||
} else {
|
||||
// idClaims may be nil when no ID token is present; templates
|
||||
// referencing .Claims.* will simply produce empty values, which
|
||||
// matches the prior behavior.
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": session.GetAccessToken(),
|
||||
"IDToken": session.GetIDToken(),
|
||||
"IDToken": idToken,
|
||||
"RefreshToken": session.GetRefreshToken(),
|
||||
"Claims": claims,
|
||||
"Claims": idClaims,
|
||||
}
|
||||
|
||||
for headerName, tmpl := range t.headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := tmpl.Execute(&buf, templateData); err != nil {
|
||||
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
|
||||
req.Header.Set(headerName, headerValue)
|
||||
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
}
|
||||
session.MarkDirty()
|
||||
t.logger.Debugf("Session marked dirty after templated header processing.")
|
||||
// NOTE: templates only mutate request headers (not session state),
|
||||
// so we deliberately do NOT MarkDirty / Save here. Previously every
|
||||
// authenticated request with header templates re-encrypted and
|
||||
// rewrote all session cookies, which was a measurable CPU and
|
||||
// Set-Cookie tax on dashboards that poll many panels per second.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,7 +571,23 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
// Strip OIDC session cookies before forwarding to the backend to prevent
|
||||
// HTTP 431 "Request Header Fields Too Large" errors (GitHub issue #122).
|
||||
if t.stripAuthCookies {
|
||||
prefix := t.sessionManager.GetCookiePrefix()
|
||||
filtered := make([]*http.Cookie, 0, len(req.Cookies()))
|
||||
for _, c := range req.Cookies() {
|
||||
if !strings.HasPrefix(c.Name, prefix) {
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
}
|
||||
req.Header.Del("Cookie")
|
||||
for _, c := range filtered {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable
|
||||
// This is critical for allowing users to clear their session when the provider is down
|
||||
func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "/",
|
||||
forceHTTPS: false,
|
||||
}
|
||||
// Note: initComplete is NOT closed, simulating OIDC provider being unavailable
|
||||
|
||||
req := httptest.NewRequest("GET", "/logout", nil)
|
||||
req.Host = "example.com"
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// Should redirect to post-logout URI even without OIDC initialization
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected redirect (302) for logout, got %d", rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Expected Location header for logout redirect")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddlewareDomainRestrictions tests domain-based access control
|
||||
// NOTE: Currently commented out due to complex session setup requirements
|
||||
// These scenarios are tested indirectly through integration tests
|
||||
@@ -129,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
@@ -171,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetUserIdentifier("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
@@ -220,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
@@ -259,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetUserIdentifier("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -289,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
@@ -317,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -351,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetUserIdentifier(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
+34
-44
@@ -18,7 +18,6 @@ type RefreshCoordinator struct {
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
cleanupTimers map[string]*time.Timer
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
delayedCleanupQueue chan delayedCleanupItem
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
metrics *RefreshMetrics
|
||||
logger *Logger
|
||||
@@ -107,12 +106,6 @@ type RefreshMetrics struct {
|
||||
currentInFlightRefreshes int32
|
||||
}
|
||||
|
||||
// delayedCleanupItem represents an item scheduled for delayed cleanup
|
||||
type delayedCleanupItem struct {
|
||||
cleanupAt time.Time
|
||||
tokenHash string
|
||||
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
type RefreshCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
@@ -143,7 +136,6 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
delayedCleanupQueue: make(chan delayedCleanupItem, 1000), // Buffered channel for cleanup items
|
||||
cleanupTimers: make(map[string]*time.Timer),
|
||||
circuitBreaker: &RefreshCircuitBreaker{
|
||||
config: RefreshCircuitBreakerConfig{
|
||||
@@ -158,10 +150,6 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
rc.wg.Add(1)
|
||||
go rc.cleanupRoutine()
|
||||
|
||||
// Start delayed cleanup processor (single goroutine processes all cleanup timers)
|
||||
rc.wg.Add(1)
|
||||
go rc.processDelayedCleanups()
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
@@ -234,7 +222,7 @@ func (rc *RefreshCoordinator) CoordinateRefresh(
|
||||
// Returns (operation, false, nil) if joined an existing operation
|
||||
// Returns (nil, false, error) if the operation was rejected
|
||||
func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
ctx context.Context,
|
||||
_ context.Context,
|
||||
sessionID string,
|
||||
tokenHash string,
|
||||
refreshToken string,
|
||||
@@ -293,7 +281,7 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
// executeRefreshAsync performs the actual refresh operation asynchronously
|
||||
func (rc *RefreshCoordinator) executeRefreshAsync(
|
||||
operation *refreshOperation,
|
||||
sessionID string,
|
||||
_ string, // sessionID - reserved for future metrics/logging
|
||||
tokenHash string,
|
||||
refreshFunc func() (*TokenResponse, error),
|
||||
) {
|
||||
@@ -377,35 +365,19 @@ func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
}
|
||||
|
||||
// performCleanup removes the operation from the in-flight map
|
||||
// performCleanup removes the operation from the in-flight map.
|
||||
// Idempotent: only decrements the in-flight counter if an entry was actually
|
||||
// removed. This guards against any future path accidentally calling cleanup
|
||||
// twice for the same tokenHash (which would corrupt the refresh budget).
|
||||
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
|
||||
rc.refreshMutex.Lock()
|
||||
delete(rc.inFlightRefreshes, tokenHash)
|
||||
_, existed := rc.inFlightRefreshes[tokenHash]
|
||||
if existed {
|
||||
delete(rc.inFlightRefreshes, tokenHash)
|
||||
}
|
||||
rc.refreshMutex.Unlock()
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
}
|
||||
|
||||
// processDelayedCleanups processes delayed cleanup requests from the queue
|
||||
// This is a single goroutine that handles all delayed cleanups
|
||||
func (rc *RefreshCoordinator) processDelayedCleanups() {
|
||||
defer rc.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case item := <-rc.delayedCleanupQueue:
|
||||
// Wait until cleanup time
|
||||
waitDuration := time.Until(item.cleanupAt)
|
||||
if waitDuration > 0 {
|
||||
select {
|
||||
case <-time.After(waitDuration):
|
||||
case <-rc.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
rc.performCleanup(item.tokenHash)
|
||||
case <-rc.stopChan:
|
||||
return
|
||||
}
|
||||
if existed {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -494,15 +466,33 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
return refreshCoordinatorSessionID(token)
|
||||
}
|
||||
|
||||
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
|
||||
// for both deduplication and per-session attempt tracking. Using sha256 of the
|
||||
// raw token means each rotation produces a fresh sessionID with its own attempt
|
||||
// budget, which is what we want.
|
||||
func refreshCoordinatorSessionID(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure
|
||||
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
|
||||
// coordinated refresh result. It is wider than RefreshTimeout so a follower
|
||||
// always sees the leader's result instead of timing out independently.
|
||||
const refreshCoordinatorWaitTimeout = 35 * time.Second
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure by
|
||||
// consulting the global memory monitor. Returns true when pressure reaches
|
||||
// High or Critical, at which point we refuse new refresh operations to
|
||||
// avoid aggravating an already-stressed heap.
|
||||
func (rc *RefreshCoordinator) isUnderMemoryPressure() bool {
|
||||
// This is a simplified check - in production you'd want to use runtime.MemStats
|
||||
// or system-specific memory monitoring
|
||||
return false // Placeholder - implement actual memory check
|
||||
monitor := GetGlobalMemoryMonitor()
|
||||
if monitor == nil {
|
||||
return false
|
||||
}
|
||||
return monitor.GetMemoryPressure() >= MemoryPressureHigh
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically cleans up stale tracking entries
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stubTokenExchanger lets us count how many upstream refresh-token grants
|
||||
// happen for a given refresh_token across concurrent middleware-level calls.
|
||||
type stubTokenExchanger struct {
|
||||
calls int32
|
||||
delay time.Duration
|
||||
resp *TokenResponse
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
|
||||
// concurrent calls to coordinatedTokenRefresh with the same refresh token
|
||||
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
|
||||
//
|
||||
// Without the wireup this assertion fails (one upstream call per goroutine).
|
||||
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 100 * time.Millisecond,
|
||||
resp: &TokenResponse{
|
||||
AccessToken: "new_access",
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const concurrency = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "new_access" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
got := atomic.LoadInt32(&stub.calls)
|
||||
// Up to 2 is acceptable to absorb the documented timing slack in the
|
||||
// existing coordinator tests (e.g. operation just cleaned up before a
|
||||
// late goroutine reads the in-flight map). Anything beyond that means
|
||||
// coalescing is broken.
|
||||
if got > 2 {
|
||||
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
|
||||
// coordinator path so existing tests that build TraefikOidc literals stay
|
||||
// green.
|
||||
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenExchanger: stub,
|
||||
// refreshCoordinator deliberately nil
|
||||
}
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream call, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
|
||||
// distinct refresh tokens are not falsely coalesced.
|
||||
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 20 * time.Millisecond,
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
cfg.DeduplicationCleanupDelay = 0
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const distinct = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(distinct)
|
||||
for i := 0; i < distinct; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
|
||||
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
|
||||
// replica dedup contract: Set/Get with TTL. Used in place of the universal
|
||||
// cache singleton so these tests stay hermetic.
|
||||
type inMemoryCache struct {
|
||||
entries map[string]inMemoryCacheEntry
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type inMemoryCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
value interface{}
|
||||
}
|
||||
|
||||
func newInMemoryCache() *inMemoryCache {
|
||||
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Get(key string) (any, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
e, ok := c.entries[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(e.expiresAt) {
|
||||
delete(c.entries, key)
|
||||
return nil, false
|
||||
}
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.entries, key)
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) SetMaxSize(int) {}
|
||||
func (c *inMemoryCache) Cleanup() {}
|
||||
func (c *inMemoryCache) Close() {}
|
||||
func (c *inMemoryCache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.entries)
|
||||
}
|
||||
func (c *inMemoryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries = map[string]inMemoryCacheEntry{}
|
||||
}
|
||||
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
|
||||
|
||||
// erroringTokenExchanger always errors - simulates an IdP rejection.
|
||||
type erroringTokenExchanger struct {
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
return nil, errors.New("invalid_grant")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
|
||||
|
||||
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
|
||||
// replica having just refreshed: the shared cache already has the result, so
|
||||
// this pod must reuse it without ever calling the IdP.
|
||||
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "should_not_be_called"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
preExisting := &TokenResponse{
|
||||
AccessToken: "from_peer",
|
||||
RefreshToken: "rotated_by_peer",
|
||||
IDToken: "id_from_peer",
|
||||
}
|
||||
rt := "shared_refresh_token"
|
||||
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "from_peer" {
|
||||
t.Fatalf("expected peer-provided response, got %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 0 {
|
||||
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
|
||||
// cache miss the leader stores its result for peers to find within the TTL.
|
||||
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "fresh_grant"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
rt := "fresh_refresh_token"
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected 1 upstream call, got %d", got)
|
||||
}
|
||||
|
||||
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
|
||||
if !ok {
|
||||
t.Fatal("expected refresh result to be cached after upstream success")
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("cached value malformed: %+v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
|
||||
// dedup cache when the IdP rejects the grant. Peers must run their own
|
||||
// refresh; they cannot inherit an error.
|
||||
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
|
||||
failing := &erroringTokenExchanger{}
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: failing,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
|
||||
t.Fatal("expected an error from the failing exchanger")
|
||||
}
|
||||
if cache.Size() != 0 {
|
||||
t.Fatalf("error result must not be cached, size=%d", cache.Size())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
|
||||
// reads from. We can't reuse sessionPool.Get() here because that requires a
|
||||
// fully initialized SessionManager - overkill for this unit-level check.
|
||||
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
|
||||
t.Helper()
|
||||
rs := sessions.NewSession(nil, "refresh")
|
||||
if !issuedAt.IsZero() {
|
||||
rs.Values["issued_at"] = issuedAt.Unix()
|
||||
}
|
||||
return &SessionData{
|
||||
refreshSession: rs,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 0}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false within max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
|
||||
if !tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
|
||||
var tr *TraefikOidc
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil receiver must not panic and must return false")
|
||||
}
|
||||
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil session must return false")
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
|
||||
// Simulate successful Azure authentication
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
|
||||
assert.Equal(t, "user@example.com", session2.GetEmail())
|
||||
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
|
||||
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
|
||||
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
|
||||
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
|
||||
|
||||
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetUserIdentifier("attacker@evil.com")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get the email from the session
|
||||
email := session.GetEmail()
|
||||
email := session.GetUserIdentifier()
|
||||
w.Header().Set("X-User-Email", email)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
+90
-42
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
|
||||
A string `json:"a,omitempty"`
|
||||
R string `json:"r,omitempty"`
|
||||
I string `json:"i,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Ui string `json:"ui,omitempty"`
|
||||
Cs string `json:"cs,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
Cv string `json:"cv,omitempty"`
|
||||
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
|
||||
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
|
||||
// All other mainSession.Values keys are stored in the X (extra) field.
|
||||
var knownSessionKeys = map[string]bool{
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"email": true,
|
||||
"authenticated": true,
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"user_identifier": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
@@ -164,7 +164,7 @@ func decompressCombinedPayload(compressed string) (*combinedSessionPayload, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gr.Close()
|
||||
defer func() { _ = gr.Close() }()
|
||||
|
||||
// Limit decompressed size to prevent zip bombs
|
||||
limitedReader := io.LimitReader(gr, 512*1024) // 512KB max
|
||||
@@ -500,6 +500,11 @@ func (sm *SessionManager) combinedChunkCookieName(chunkIndex int) string {
|
||||
return fmt.Sprintf("%s_%d", sm.combinedCookieName(), chunkIndex)
|
||||
}
|
||||
|
||||
// GetCookiePrefix returns the cookie prefix used for all OIDC session cookies.
|
||||
func (sm *SessionManager) GetCookiePrefix() string {
|
||||
return sm.cookiePrefix
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the SessionManager and all its background tasks
|
||||
func (sm *SessionManager) Shutdown() error {
|
||||
var shutdownErr error
|
||||
@@ -1129,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
|
||||
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
|
||||
|
||||
// Populate legacy session values from combined payload
|
||||
sessionData.mainSession.Values["email"] = payload.E
|
||||
sessionData.mainSession.Values["user_identifier"] = payload.Ui
|
||||
sessionData.mainSession.Values["authenticated"] = payload.Au
|
||||
sessionData.mainSession.Values["csrf"] = payload.Cs
|
||||
sessionData.mainSession.Values["nonce"] = payload.N
|
||||
@@ -1211,6 +1216,18 @@ type SessionData struct {
|
||||
dirty bool
|
||||
|
||||
inUse bool
|
||||
|
||||
// cachedClaimsToken is the ID token string whose claims were last parsed and
|
||||
// cached. A lazy, per-request cache to avoid re-parsing the JWT on every
|
||||
// authenticated request (e.g. for headerTemplates). Protected by sessionMutex.
|
||||
cachedClaimsToken string
|
||||
|
||||
// cachedClaims holds the parsed claims for cachedClaimsToken.
|
||||
cachedClaims map[string]interface{}
|
||||
|
||||
// cachedClaimsErr holds the parse error (if any) for cachedClaimsToken so
|
||||
// failures are not retried within the same request.
|
||||
cachedClaimsErr error
|
||||
}
|
||||
|
||||
// IsDirty returns true if the session data has been modified since it was last loaded or saved.
|
||||
@@ -1261,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
|
||||
A: sd.getAccessTokenUnsafe(),
|
||||
R: sd.getRefreshTokenUnsafe(),
|
||||
I: sd.getIDTokenUnsafe(),
|
||||
E: sd.getEmailUnsafe(),
|
||||
Ui: sd.getUserIdentifierUnsafe(),
|
||||
Au: sd.getAuthenticatedUnsafe(),
|
||||
Cs: sd.getCSRFUnsafe(),
|
||||
N: sd.getNonceUnsafe(),
|
||||
@@ -1548,9 +1565,10 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
}()
|
||||
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
sd.clearAllSessionData(r, true)
|
||||
|
||||
// Release the lock before calling Save to prevent deadlock
|
||||
sd.sessionMutex.Unlock()
|
||||
|
||||
// This is primarily for testing - in production w will often be nil
|
||||
var err error
|
||||
@@ -1588,7 +1606,7 @@ func (sd *SessionData) returnToPoolSafely() {
|
||||
// Parameters:
|
||||
// - r: The HTTP request context.
|
||||
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
func (sd *SessionData) clearTokenChunks(_ *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
clearSessionValues(session, true)
|
||||
}
|
||||
@@ -1731,6 +1749,12 @@ func (sd *SessionData) Reset() {
|
||||
sd.request = nil
|
||||
sd.useCombinedStorage = true // Reset to use combined storage by default
|
||||
|
||||
// Drop any cached claims so pooled SessionData does not leak claim data
|
||||
// between requests/users.
|
||||
sd.cachedClaimsToken = ""
|
||||
sd.cachedClaims = nil
|
||||
sd.cachedClaimsErr = nil
|
||||
|
||||
// Reset the refresh mutex to ensure clean state
|
||||
// Note: We don't need to lock it since sessionMutex is already held
|
||||
// and this session is not in use by any request
|
||||
@@ -1820,23 +1844,12 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
if token != "" {
|
||||
dotCount := strings.Count(token, ".")
|
||||
// Reject tokens with exactly 1 dot (invalid format - neither JWT nor opaque)
|
||||
if dotCount == 1 {
|
||||
if sd.manager != nil && sd.manager.logger != nil {
|
||||
sd.manager.logger.Debug("Invalid token format during storage (dots: %d) - rejecting", dotCount)
|
||||
}
|
||||
return
|
||||
}
|
||||
// For opaque tokens (no dots), ensure minimum length for security
|
||||
if dotCount == 0 && len(token) < 20 {
|
||||
if len(token) < 20 {
|
||||
if sd.manager != nil && sd.manager.logger != nil {
|
||||
sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
|
||||
}
|
||||
return
|
||||
}
|
||||
// Tokens with 2 dots are JWTs, tokens with 0 dots are opaque
|
||||
// Both are valid formats
|
||||
}
|
||||
|
||||
currentAccessToken := sd.getAccessTokenUnsafe()
|
||||
@@ -2456,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address.
|
||||
// The email is extracted from ID token claims and used for
|
||||
// authorization decisions and header injection.
|
||||
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
|
||||
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
|
||||
// upn, preferred_username, etc.). The value is used for authorization
|
||||
// decisions and header injection.
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
// - The user identifier string, or an empty string if not set.
|
||||
func (sd *SessionData) GetUserIdentifier() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// SetEmail stores the authenticated user's email address.
|
||||
// The email is typically extracted from the 'email' claim in the ID token.
|
||||
// SetUserIdentifier stores the authenticated user's identifier value.
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
|
||||
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentVal, _ := sd.mainSession.Values["email"].(string)
|
||||
if currentVal != email {
|
||||
sd.mainSession.Values["email"] = email
|
||||
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
if currentVal != userIdentifier {
|
||||
sd.mainSession.Values["user_identifier"] = userIdentifier
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
@@ -2519,6 +2532,41 @@ func (sd *SessionData) GetIDToken() string {
|
||||
return sd.getIDTokenUnsafe()
|
||||
}
|
||||
|
||||
// GetIDTokenClaims returns claims parsed from the current ID token, caching
|
||||
// the result on the SessionData so repeated callers within the same request
|
||||
// do not re-parse the JWT. The cache is keyed on the ID token string and is
|
||||
// cleared when the SessionData is reset (see Reset) or when the ID token
|
||||
// changes (e.g. after a refresh).
|
||||
//
|
||||
// The parser parameter is typically the TraefikOidc.extractClaimsFunc, which
|
||||
// lets tests inject mocks just like the direct call it replaces.
|
||||
//
|
||||
// Returns an empty claims map and a nil error when the session has no ID
|
||||
// token, matching the existing "no-op" behavior of the caller sites.
|
||||
func (sd *SessionData) GetIDTokenClaims(parser func(string) (map[string]interface{}, error)) (map[string]interface{}, error) {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
token := sd.getIDTokenUnsafe()
|
||||
if token == "" {
|
||||
// Invalidate any stale cache without running the parser.
|
||||
sd.cachedClaimsToken = ""
|
||||
sd.cachedClaims = nil
|
||||
sd.cachedClaimsErr = nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if sd.cachedClaimsToken == token && (sd.cachedClaims != nil || sd.cachedClaimsErr != nil) {
|
||||
return sd.cachedClaims, sd.cachedClaimsErr
|
||||
}
|
||||
|
||||
claims, err := parser(token)
|
||||
sd.cachedClaimsToken = token
|
||||
sd.cachedClaims = claims
|
||||
sd.cachedClaimsErr = err
|
||||
return claims, err
|
||||
}
|
||||
|
||||
// getIDTokenUnsafe retrieves the ID token without acquiring locks.
|
||||
// Enhanced ID token retrieval with comprehensive integrity checks and chunking support.
|
||||
// Used when the session mutex is already held to prevent deadlocks.
|
||||
@@ -2578,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getEmailUnsafe retrieves the email without acquiring locks.
|
||||
func (sd *SessionData) getEmailUnsafe() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
|
||||
func (sd *SessionData) getUserIdentifierUnsafe() string {
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
|
||||
|
||||
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
|
||||
s.False(session.IsDirty())
|
||||
}
|
||||
|
||||
// TestSessionData_SetEmail tests email setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
|
||||
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
session, err := s.sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Set email
|
||||
session.SetEmail("test@example.com")
|
||||
s.Equal("test@example.com", session.GetEmail())
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.Equal("test@example.com", session.GetUserIdentifier())
|
||||
s.True(session.IsDirty())
|
||||
}
|
||||
|
||||
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
|
||||
// Set some data
|
||||
err = session.SetAuthenticated(true)
|
||||
s.Require().NoError(err)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetCSRF("csrf-token")
|
||||
|
||||
// Clear session
|
||||
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Modify session
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.True(session.IsDirty())
|
||||
|
||||
// Save session
|
||||
|
||||
@@ -926,6 +926,8 @@ func (cm *ChunkManager) detectRepeatedCharacters(token string, config TokenConfi
|
||||
//
|
||||
// Returns:
|
||||
// - An error if the token is expired or has invalid expiration, nil if valid.
|
||||
//
|
||||
//nolint:unparam // error return kept for API consistency and future use
|
||||
func (cm *ChunkManager) validateTokenExpiration(token string, config TokenConfig) error {
|
||||
if !strings.Contains(token, ".") {
|
||||
return nil
|
||||
|
||||
+6
-6
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Set up initial session state (what user has when first logging in)
|
||||
session1.SetAuthenticated(true)
|
||||
session1.SetEmail(originalUserData["email"].(string))
|
||||
session1.SetUserIdentifier(originalUserData["email"].(string))
|
||||
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
|
||||
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
|
||||
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
|
||||
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
// Simulate what happens when middleware detects expired tokens
|
||||
// It should preserve session state while attempting token refresh
|
||||
originalAuth := session2.GetAuthenticated()
|
||||
originalEmail := session2.GetEmail()
|
||||
originalEmail := session2.GetUserIdentifier()
|
||||
|
||||
// Reconstruct user data from individual stored keys
|
||||
originalUserDataStored := make(map[string]interface{})
|
||||
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Verify all session data is still intact after token refresh
|
||||
postRefreshAuth := session2.GetAuthenticated()
|
||||
postRefreshEmail := session2.GetEmail()
|
||||
postRefreshEmail := session2.GetUserIdentifier()
|
||||
userDataPresent := true
|
||||
for k := range originalUserData {
|
||||
if session2.mainSession.Values["user_data_"+k] == nil {
|
||||
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with specific creation time
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
|
||||
|
||||
// Create tokens with specific expiry
|
||||
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with data that should be preserved or removed
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("cleanup@example.com")
|
||||
session.SetUserIdentifier("cleanup@example.com")
|
||||
|
||||
session.mainSession.Values["user_data"] = "Test User|user-123"
|
||||
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
|
||||
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
if scenario.shouldCleanup {
|
||||
if sessionTooOld {
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
for key := range session.mainSession.Values {
|
||||
|
||||
+81
-116
@@ -1,6 +1,7 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -54,6 +55,15 @@ type Config struct {
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
|
||||
// a stored refresh token. Once the token has been in the session longer
|
||||
// than this, requests treat it as expired up-front - returning 401 to
|
||||
// AJAX callers and triggering full re-auth on navigations - instead of
|
||||
// hammering the IdP with grants that will only fail with invalid_grant.
|
||||
// IdPs do not expose RT TTL on the wire, so this is intentionally a
|
||||
// conservative heuristic; tune to match your provider configuration.
|
||||
// Default 21600 (6h). Set to 0 to disable the check.
|
||||
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
|
||||
SessionMaxAge int `json:"sessionMaxAge"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
@@ -65,6 +75,51 @@ type Config struct {
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
|
||||
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
|
||||
StripAuthCookies bool `json:"stripAuthCookies,omitempty"`
|
||||
EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
|
||||
EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
|
||||
BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
|
||||
FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
|
||||
// CACertPath is an optional filesystem path to a PEM-encoded CA bundle used
|
||||
// to verify the OIDC provider's TLS certificate. Use this when the provider
|
||||
// is signed by an internal/private CA that is not in the system trust store.
|
||||
CACertPath string `json:"caCertPath,omitempty"`
|
||||
// CACertPEM is an optional inline PEM-encoded CA bundle, equivalent to
|
||||
// CACertPath but supplied directly in the middleware configuration. Both
|
||||
// may be set; certificates from both sources are combined.
|
||||
CACertPEM string `json:"caCertPEM,omitempty"`
|
||||
// InsecureSkipVerify disables TLS certificate verification for the OIDC
|
||||
// provider. Intended ONLY for local development against self-signed
|
||||
// providers. Enabling this in production is a security hole — prefer
|
||||
// CACertPath/CACertPEM. Emits a loud warning at startup.
|
||||
InsecureSkipVerify bool `json:"insecureSkipVerify,omitempty"`
|
||||
}
|
||||
|
||||
// loadCACertPool assembles an x509.CertPool from CACertPath and CACertPEM.
|
||||
// Returns (nil, nil) when neither is configured — callers should fall back to
|
||||
// the system trust store. Returns a descriptive error if a PEM source is
|
||||
// configured but contains no parseable certificates, so misconfigurations
|
||||
// surface at startup rather than as unexplained TLS failures at runtime.
|
||||
func (c *Config) loadCACertPool() (*x509.CertPool, error) {
|
||||
if c.CACertPath == "" && c.CACertPEM == "" {
|
||||
return nil, nil
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if c.CACertPath != "" {
|
||||
data, err := os.ReadFile(c.CACertPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read caCertPath %q: %w", c.CACertPath, err)
|
||||
}
|
||||
if !pool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("caCertPath %q: no valid PEM certificates found", c.CACertPath)
|
||||
}
|
||||
}
|
||||
if c.CACertPEM != "" {
|
||||
if !pool.AppendCertsFromPEM([]byte(c.CACertPEM)) {
|
||||
return nil, fmt.Errorf("caCertPEM: no valid PEM certificates found")
|
||||
}
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// RedisConfig configures Redis cache backend settings for distributed caching.
|
||||
@@ -98,8 +153,15 @@ type DynamicClientRegistrationConfig struct {
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
// StorageBackend specifies where to store DCR credentials: "file", "redis", or "auto"
|
||||
// - "file": Use file-based storage (default for backward compatibility)
|
||||
// - "redis": Use Redis exclusively (fails if Redis unavailable)
|
||||
// - "auto": Use Redis if available, fallback to file (default)
|
||||
StorageBackend string `json:"storageBackend,omitempty"`
|
||||
// RedisKeyPrefix is the prefix for Redis keys when using Redis storage (default: "dcr:creds:")
|
||||
RedisKeyPrefix string `json:"redisKeyPrefix,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
@@ -194,6 +256,7 @@ func CreateConfig() *Config {
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
|
||||
}
|
||||
@@ -317,6 +380,11 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate refresh-token max-age heuristic
|
||||
if c.MaxRefreshTokenAgeSeconds < 0 {
|
||||
return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate audience if specified
|
||||
if c.Audience != "" {
|
||||
// Validate audience format - should be a valid identifier or URL
|
||||
@@ -722,7 +790,18 @@ func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// IsDebug reports whether debug-level logging is enabled.
|
||||
// Callers should use this to avoid expensive format-string expansion
|
||||
// (e.g. on hot paths under yaegi) when debug output would be discarded.
|
||||
func (l *Logger) IsDebug() bool {
|
||||
if l == nil || l.logDebug == nil {
|
||||
return false
|
||||
}
|
||||
return l.logDebug.Writer() != io.Discard
|
||||
}
|
||||
|
||||
// newNoOpLogger creates a logger that discards all output.
|
||||
//
|
||||
// Deprecated: Use GetSingletonNoOpLogger() instead for better memory efficiency.
|
||||
func newNoOpLogger() *Logger {
|
||||
return GetSingletonNoOpLogger()
|
||||
@@ -737,15 +816,6 @@ func newNoOpLogger() *Logger {
|
||||
// - code: The HTTP status code for the response.
|
||||
// - logger: The Logger instance to use for logging the error.
|
||||
//
|
||||
// handleError writes an HTTP error response with the specified status code and message.
|
||||
// It logs the error and sets appropriate headers before writing the response.
|
||||
//
|
||||
//lint:ignore U1000 Kept for potential future error handling
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error("%s", message)
|
||||
http.Error(w, message, code)
|
||||
}
|
||||
|
||||
// GetSecurityHeadersApplier returns a function that applies security headers
|
||||
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
|
||||
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
|
||||
@@ -1051,111 +1121,6 @@ func (rc *RedisConfig) ApplyEnvFallbacks() {
|
||||
}
|
||||
}
|
||||
|
||||
// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
|
||||
// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
|
||||
// This function is kept for backward compatibility but should not be used directly.
|
||||
func LoadRedisConfigFromEnv() *RedisConfig {
|
||||
// Check if Redis is enabled
|
||||
enabledStr := os.Getenv("REDIS_ENABLED")
|
||||
if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &RedisConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
// Parse numeric values
|
||||
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
|
||||
if db, err := strconv.Atoi(dbStr); err == nil {
|
||||
config.DB = db
|
||||
}
|
||||
}
|
||||
|
||||
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
|
||||
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
|
||||
config.PoolSize = poolSize
|
||||
}
|
||||
}
|
||||
|
||||
if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
|
||||
config.ConnectTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
|
||||
config.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
|
||||
config.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Parse boolean values
|
||||
if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
|
||||
config.EnableTLS = true
|
||||
}
|
||||
|
||||
if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
|
||||
config.TLSSkipVerify = true
|
||||
}
|
||||
|
||||
// Parse hybrid mode settings
|
||||
if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
|
||||
if size, err := strconv.Atoi(l1SizeStr); err == nil {
|
||||
config.HybridL1Size = size
|
||||
}
|
||||
}
|
||||
|
||||
if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
|
||||
if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
|
||||
config.HybridL1MemoryMB = memory
|
||||
}
|
||||
}
|
||||
|
||||
// Parse circuit breaker settings
|
||||
if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
|
||||
config.EnableCircuitBreaker = false
|
||||
} else {
|
||||
config.EnableCircuitBreaker = true // Default to enabled
|
||||
}
|
||||
|
||||
if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
|
||||
if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
|
||||
config.CircuitBreakerThreshold = threshold
|
||||
}
|
||||
}
|
||||
|
||||
if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
|
||||
config.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Parse health check settings
|
||||
if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
|
||||
config.EnableHealthCheck = false
|
||||
} else {
|
||||
config.EnableHealthCheck = true // Default to enabled
|
||||
}
|
||||
|
||||
if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
|
||||
if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
|
||||
config.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// Apply defaults after loading from env
|
||||
config.ApplyDefaults()
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
|
||||
+13
-6
@@ -548,17 +548,24 @@ func (gc *GenericCache) Delete(key string) {
|
||||
delete(gc.data, key)
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically cleans up the cache
|
||||
// cleanupRoutine periodically wipes the cache.
|
||||
//
|
||||
// NOTE: GenericCache does not track per-entry timestamps, so this is a
|
||||
// "clear-all on tick" strategy — every `gc.ttl` interval the entire map
|
||||
// is replaced, regardless of when each entry was written. This is the
|
||||
// intentional (simplified) behavior of GenericCache, which exists mainly
|
||||
// as a generic fallback for tests and non-typed caches. Callers that
|
||||
// require true per-entry TTL must use UniversalCache / UnifiedCache which
|
||||
// track expiry per entry.
|
||||
func (gc *GenericCache) cleanupRoutine() {
|
||||
ticker := time.NewTicker(gc.ttl)
|
||||
defer ticker.Stop()
|
||||
wipeTicker := time.NewTicker(gc.ttl)
|
||||
defer wipeTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-wipeTicker.C:
|
||||
gc.mu.Lock()
|
||||
// Simple cleanup - clear all data after TTL
|
||||
// In production, you'd track individual entry TTLs
|
||||
// Clear-all on tick, not per-entry TTL (see function doc).
|
||||
gc.data = make(map[string]interface{})
|
||||
gc.mu.Unlock()
|
||||
case <-gc.stopChan:
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -251,6 +254,30 @@ func TestSingletonResourceManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// createMockOIDCServer creates a mock OIDC server for testing
|
||||
func createMockOIDCServer() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/authorize",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"jwks_uri": "https://example.com/jwks",
|
||||
"userinfo_endpoint": "https://example.com/userinfo",
|
||||
})
|
||||
case "/jwks":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"keys": []interface{}{},
|
||||
})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// TestContextAwareGoroutineManagement tests context-aware goroutine management
|
||||
func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
t.Run("GoroutineCleanupOnContextCancel", func(t *testing.T) {
|
||||
@@ -259,13 +286,17 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock OIDC server
|
||||
mockServer := createMockOIDCServer()
|
||||
defer mockServer.Close()
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create a TraefikOidc instance with context
|
||||
config := &Config{
|
||||
ProviderURL: "https://example.com",
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
@@ -308,12 +339,20 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock OIDC servers
|
||||
mockServer1 := createMockOIDCServer()
|
||||
defer mockServer1.Close()
|
||||
mockServer2 := createMockOIDCServer()
|
||||
defer mockServer2.Close()
|
||||
mockServer3 := createMockOIDCServer()
|
||||
defer mockServer3.Close()
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
configs := []Config{
|
||||
{ProviderURL: "https://example1.com", ClientID: "client1", ClientSecret: "secret1"},
|
||||
{ProviderURL: "https://example2.com", ClientID: "client2", ClientSecret: "secret2"},
|
||||
{ProviderURL: "https://example3.com", ClientID: "client3", ClientSecret: "secret3"},
|
||||
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
|
||||
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
|
||||
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
|
||||
}
|
||||
|
||||
var plugins []*TraefikOidc
|
||||
@@ -366,6 +405,13 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock OIDC servers
|
||||
mockServers := make([]*httptest.Server, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
mockServers[i] = createMockOIDCServer()
|
||||
defer mockServers[i].Close()
|
||||
}
|
||||
|
||||
rm := GetResourceManager()
|
||||
|
||||
// Register singleton cleanup task
|
||||
@@ -386,7 +432,7 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
for i := 0; i < 3; i++ {
|
||||
ctx := context.Background()
|
||||
config := &Config{
|
||||
ProviderURL: fmt.Sprintf("https://example%d.com", i),
|
||||
ProviderURL: mockServers[i].URL,
|
||||
ClientID: fmt.Sprintf("client%d", i),
|
||||
ClientSecret: fmt.Sprintf("secret%d", i),
|
||||
}
|
||||
|
||||
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(tf.fixtures.UserEmail)
|
||||
session.SetUserIdentifier(tf.fixtures.UserEmail)
|
||||
session.SetAccessToken(tf.fixtures.AccessToken)
|
||||
session.SetRefreshToken(tf.fixtures.RefreshToken)
|
||||
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
|
||||
|
||||
@@ -22,7 +22,7 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
|
||||
// Test helper adapters for the new test files
|
||||
|
||||
// resetGlobalState resets all global singletons to prevent test interference
|
||||
// nolint:unused // Kept for potential future use in integration tests
|
||||
//nolint:unused // Kept for potential future use in integration tests
|
||||
/*
|
||||
func resetGlobalState() {
|
||||
// Reset global task registry first to stop all background tasks
|
||||
@@ -137,7 +137,7 @@ func (tc *testCleanup) cleanupAll() {
|
||||
}
|
||||
|
||||
// createTestConfig creates a config with all required fields populated for testing
|
||||
// nolint:unused // Kept for potential future use in integration tests
|
||||
//nolint:unused // Kept for potential future use in integration tests
|
||||
/*
|
||||
func createTestConfig() *Config {
|
||||
config := CreateConfig()
|
||||
@@ -151,7 +151,7 @@ func createTestConfig() *Config {
|
||||
*/
|
||||
|
||||
// setupTestOIDCMiddleware creates a test OIDC middleware instance with mock servers
|
||||
// nolint:unused // Kept for potential future use in integration tests
|
||||
//nolint:unused // Kept for potential future use in integration tests
|
||||
/*
|
||||
func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httptest.Server) {
|
||||
// Reset global state to ensure test isolation
|
||||
@@ -339,7 +339,7 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
|
||||
*/
|
||||
|
||||
// createMockJWT creates a mock JWT token for testing - adapter for existing tests
|
||||
// nolint:unused // Kept for potential future use in integration tests
|
||||
//nolint:unused // Kept for potential future use in integration tests
|
||||
/*
|
||||
func createMockJWT(t *testing.T, sub, email string) string {
|
||||
return ValidIDToken
|
||||
@@ -361,7 +361,7 @@ func createTestSession() *SessionData {
|
||||
}
|
||||
|
||||
// injectSessionIntoRequest saves the session and adds the resulting cookies to the request
|
||||
// nolint:unused // Kept for potential future use in integration tests
|
||||
//nolint:unused // Kept for potential future use in integration tests
|
||||
/*
|
||||
func injectSessionIntoRequest(t *testing.T, req *http.Request, session *SessionData) {
|
||||
// Create a response recorder to capture cookies
|
||||
|
||||
+132
-65
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -46,6 +47,17 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Hot-path fast-return: a previously-verified token has already passed
|
||||
// signature, claims, and replay checks. Skipping the parseJWT cost here
|
||||
// matters under bursty traffic (e.g. 10+ concurrent panel requests on
|
||||
// every Grafana dashboard refresh) where the same token is validated
|
||||
// dozens of times per second by validateStandardTokens.
|
||||
if t.tokenCache != nil {
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
parsedJWT, parseErr := parseJWT(token)
|
||||
if parseErr != nil {
|
||||
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
|
||||
@@ -63,12 +75,6 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check token cache FIRST - if token is already verified and cached, return immediately
|
||||
// This prevents false positives when multiple goroutines validate the same token concurrently
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only check JTI blacklist for tokens that aren't already in the cache
|
||||
// This is for FIRST-TIME validation to detect replay attacks
|
||||
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
||||
@@ -315,15 +321,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
jwksURL := t.jwksURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs && jwks != nil {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), jwksURL)
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
@@ -337,38 +334,12 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
t.safeLogDebugf("DIAGNOSTIC: Looking for kid=%s, alg=%s in JWKS", kid, alg)
|
||||
}
|
||||
|
||||
if jwks == nil {
|
||||
return fmt.Errorf("JWKS is nil, cannot verify token")
|
||||
}
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
availableKids := make([]string, 0, len(jwks.Keys))
|
||||
for _, key := range jwks.Keys {
|
||||
availableKids = append(availableKids, key.Kid)
|
||||
if key.Kid == kid {
|
||||
matchingKey = &key
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchingKey == nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogErrorf("DIAGNOSTIC: No matching key found for kid=%s. Available kids: %v", kid, availableKids)
|
||||
}
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Found matching key for kid=%s, key type: %s", kid, matchingKey.Kty)
|
||||
}
|
||||
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
pubKey, err := t.jwkCache.GetPublicKey(context.Background(), jwksURL, kid, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
return fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
|
||||
if err := verifySignatureWithKey(token, pubKey, alg); err != nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
||||
}
|
||||
@@ -451,10 +422,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
}
|
||||
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
//nolint:gocritic // Complex error handling with provider-specific conditions
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
t.logger.Debug("Refresh token expired or revoked: %v", err)
|
||||
// Clear all tokens and authentication state when refresh token is invalid
|
||||
@@ -464,7 +434,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetRefreshToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens as well to prevent any replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -506,12 +476,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
||||
return false
|
||||
}
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
||||
return false
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
|
||||
// Get token expiry information for logging
|
||||
var expiryTime time.Time
|
||||
@@ -537,7 +513,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -554,6 +530,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// coordinatedTokenRefresh routes a refresh-token grant through the
|
||||
// RefreshCoordinator so that concurrent requests sharing the same refresh
|
||||
// token coalesce into a single upstream call. This prevents the thundering
|
||||
// herd that yields invalid_grant when the IdP rotates refresh tokens.
|
||||
//
|
||||
// Falls back to a direct call when the coordinator is nil, which only
|
||||
// happens in tests that build TraefikOidc literals without going through
|
||||
// NewWithContext.
|
||||
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
|
||||
if t.refreshCoordinator == nil {
|
||||
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
parentCtx := context.Background()
|
||||
if req != nil {
|
||||
parentCtx = req.Context()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
|
||||
defer cancel()
|
||||
|
||||
sessionID := refreshCoordinatorSessionID(refreshToken)
|
||||
|
||||
return t.refreshCoordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
func() (*TokenResponse, error) {
|
||||
// Cross-replica dedup. The in-process coordinator already
|
||||
// collapses concurrent grants on this pod; this Redis-backed
|
||||
// short-TTL cache covers the (rare) case of a failover or
|
||||
// load-balancer reroute mid-refresh, where two pods would
|
||||
// otherwise both POST the same refresh_token to the IdP.
|
||||
if cached, ok := t.lookupCachedRefreshResult(sessionID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
if err == nil && resp != nil {
|
||||
t.cacheRefreshResult(sessionID, resp)
|
||||
}
|
||||
return resp, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// lookupCachedRefreshResult returns a previously-stored TokenResponse for the
|
||||
// given refresh-token hash, if one exists and is still within its short TTL.
|
||||
// The cache wraps the universal cache, which is Redis-backed in production -
|
||||
// so a "hit" here means another Traefik replica refreshed this same token
|
||||
// within the last few seconds.
|
||||
func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) {
|
||||
if t.refreshResultCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID))
|
||||
if !ok || v == nil {
|
||||
return nil, false
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); ok && tr != nil {
|
||||
return tr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cacheRefreshResult stores the new TokenResponse under the refresh-token
|
||||
// hash for a short window. TTL is intentionally tight: the rotated refresh
|
||||
// token cannot be re-presented to the IdP, and any peer waiting longer than
|
||||
// this window has almost certainly given up via its own coordinator timeout.
|
||||
func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) {
|
||||
if t.refreshResultCache == nil || resp == nil {
|
||||
return
|
||||
}
|
||||
t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL)
|
||||
}
|
||||
|
||||
// refreshResultCacheKey namespaces refresh-result entries inside the shared
|
||||
// cache namespace.
|
||||
func refreshResultCacheKey(sessionID string) string {
|
||||
return "rt-result:" + sessionID
|
||||
}
|
||||
|
||||
// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache.
|
||||
// Long enough for a sibling replica to observe the result, short enough that
|
||||
// a stale entry never re-supplies a token after the IdP has already moved on.
|
||||
const refreshResultCacheTTL = 5 * time.Second
|
||||
|
||||
// RevokeToken revokes a token locally by adding it to the blacklist cache.
|
||||
// It removes the token from the verification cache and adds both the token
|
||||
// and its JTI (if present) to the blacklist to prevent future use.
|
||||
@@ -1139,9 +1200,14 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
sessionManager := t.sessionManager
|
||||
logger := t.logger
|
||||
|
||||
// Only use the fast cleanup interval when actually running under `go test`.
|
||||
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
|
||||
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
|
||||
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
|
||||
// jwkCache.Cleanup both being no-ops there.
|
||||
cleanupInterval := 1 * time.Minute
|
||||
if isTestMode() {
|
||||
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
|
||||
if isTestMode() && runtime.Compiler != "yaegi" {
|
||||
cleanupInterval = 50 * time.Millisecond
|
||||
}
|
||||
|
||||
// Create cleanup function
|
||||
@@ -1183,25 +1249,27 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts group and role information from token claims.
|
||||
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
|
||||
// Parameters:
|
||||
// - idToken: The ID token containing claims to extract.
|
||||
// It parses the configured group/role claims from the supplied ID token.
|
||||
//
|
||||
// Returns:
|
||||
// - groups: Array of group names from the 'groups' claim.
|
||||
// - roles: Array of role names from the 'roles' claim.
|
||||
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
|
||||
// but not arrays of strings.
|
||||
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
|
||||
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
|
||||
// to avoid re-parsing the JWT.
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
||||
}
|
||||
return t.extractGroupsAndRolesFromClaims(claims)
|
||||
}
|
||||
|
||||
// extractGroupsAndRolesFromClaims extracts group and role information from
|
||||
// already-parsed claims. Hot path: callers that have a cached claims map (such
|
||||
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
|
||||
// base64+JSON decode of the JWT on every authenticated request.
|
||||
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups using configurable claim name (defaults to "groups")
|
||||
if groupsClaim, exists := claims[t.groupClaimName]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
@@ -1217,7 +1285,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles using configurable claim name (defaults to "roles")
|
||||
if rolesClaim, exists := claims[t.roleClaimName]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
|
||||
@@ -95,6 +95,7 @@ type TraefikOidc struct {
|
||||
cancelFunc context.CancelFunc
|
||||
errorRecoveryManager *ErrorRecoveryManager
|
||||
tokenResilienceManager *TokenResilienceManager
|
||||
refreshCoordinator *RefreshCoordinator
|
||||
goroutineWG *sync.WaitGroup
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
dynamicClientRegistrar *DynamicClientRegistrar
|
||||
@@ -119,14 +120,22 @@ type TraefikOidc struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
registrationURL string
|
||||
backchannelLogoutPath string
|
||||
frontchannelLogoutPath string
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
refreshGracePeriod time.Duration
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
enableFrontchannelLogout bool
|
||||
firstRequestReceived bool
|
||||
requireTokenIntrospection bool
|
||||
metadataRefreshStarted bool
|
||||
|
||||
+145
-28
@@ -21,6 +21,10 @@ const (
|
||||
CacheTypeJWK CacheType = "jwk"
|
||||
CacheTypeSession CacheType = "session"
|
||||
CacheTypeGeneral CacheType = "general"
|
||||
|
||||
// maxCacheEntrySize defines the maximum size for a single cache entry (64 MiB)
|
||||
// This prevents integer overflow when allocating memory for serialization
|
||||
maxCacheEntrySize = 64 * 1024 * 1024
|
||||
)
|
||||
|
||||
// UniversalCacheConfig provides configuration for the universal cache
|
||||
@@ -248,6 +252,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
|
||||
}
|
||||
}
|
||||
|
||||
return c.setLocal(key, value, ttl)
|
||||
}
|
||||
|
||||
// SetLocal stores a value only in the in-memory LRU, bypassing any
|
||||
// distributed backend. Use for values that don't survive JSON round-tripping
|
||||
// — interfaces holding concrete crypto keys, *big.Int, or types whose
|
||||
// unexported fields yaegi exposes under an X prefix on Marshal. Each replica
|
||||
// caches independently; correctness must not depend on cross-replica
|
||||
// coherence for these keys.
|
||||
func (c *UniversalCache) SetLocal(key string, value interface{}, ttl time.Duration) error {
|
||||
if ttl == 0 {
|
||||
ttl = c.config.DefaultTTL
|
||||
}
|
||||
return c.setLocal(key, value, ttl)
|
||||
}
|
||||
|
||||
// setLocal performs the in-memory portion of a write. ttl must already be
|
||||
// resolved against DefaultTTL by the caller.
|
||||
func (c *UniversalCache) setLocal(key string, value interface{}, ttl time.Duration) error {
|
||||
size := c.estimateSize(value)
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -302,8 +325,10 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
|
||||
c.currentMemory += size
|
||||
}
|
||||
|
||||
c.logger.Debugf("UniversalCache[%s]: Set key=%s, ttl=%v, size=%d bytes",
|
||||
c.config.Type, key, ttl, size)
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debugf("UniversalCache[%s]: Set key=%s, ttl=%v, size=%d bytes",
|
||||
c.config.Type, key, ttl, size)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -327,15 +352,54 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
|
||||
// Fall through to local cache
|
||||
} else {
|
||||
atomic.AddInt64(&c.hits, 1)
|
||||
// Update local cache with backend value
|
||||
go func() {
|
||||
_ = c.updateLocalCache(key, value, c.config.DefaultTTL)
|
||||
}()
|
||||
// Update local cache with backend value synchronously.
|
||||
// Under yaegi, goroutine spawn is 5-10x costlier than compiled Go,
|
||||
// and this path fires per-request on cold local cache.
|
||||
// updateLocalCache is cheap (map write under mutex).
|
||||
_ = c.updateLocalCache(key, value, c.config.DefaultTTL)
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.getLocal(key)
|
||||
}
|
||||
|
||||
// GetLocal retrieves a value only from the in-memory LRU, never querying the
|
||||
// distributed backend. Pair with SetLocal for values that aren't safe to
|
||||
// serialize (see SetLocal docstring).
|
||||
func (c *UniversalCache) GetLocal(key string) (interface{}, bool) {
|
||||
return c.getLocal(key)
|
||||
}
|
||||
|
||||
// getLocal returns the in-memory entry for key honoring expiry, grace
|
||||
// periods, and the RLock fast path used by token/JWK/session caches.
|
||||
func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
|
||||
// Fast read path for caches whose eviction is dominated by TTL rather than
|
||||
// access-recency (token, JWK, session). Holding only an RLock here lets all
|
||||
// concurrent readers verify cached tokens in parallel — under yaegi the
|
||||
// previous unconditional Lock serialized every JWT verify on a single
|
||||
// mutex and pinned a CPU under load.
|
||||
switch c.config.Type {
|
||||
case CacheTypeToken, CacheTypeJWK, CacheTypeSession:
|
||||
c.mu.RLock()
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
c.mu.RUnlock()
|
||||
atomic.AddInt64(&c.misses, 1)
|
||||
return nil, false
|
||||
}
|
||||
if !time.Now().After(item.ExpiresAt) {
|
||||
value := item.Value
|
||||
c.mu.RUnlock()
|
||||
atomic.AddInt64(&c.hits, 1)
|
||||
return value, true
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
// Expired — fall through to the write-locked slow path below to
|
||||
// remove the entry under exclusive access.
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -436,7 +500,7 @@ func (c *UniversalCache) Clear() {
|
||||
c.currentSize = 0
|
||||
c.currentMemory = 0
|
||||
|
||||
c.logger.Infof("UniversalCache[%s]: Cleared all items", c.config.Type)
|
||||
c.logger.Debugf("UniversalCache[%s]: Cleared all items", c.config.Type)
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
@@ -536,7 +600,9 @@ func (c *UniversalCache) evictOldest() {
|
||||
if item, exists := c.items[key]; exists {
|
||||
c.removeItem(key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -720,22 +786,6 @@ func (c *UniversalCache) SetWithMetadata(key string, value interface{}, ttl time
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTyped retrieves a typed value from the cache
|
||||
func GetTyped[T any](c *UniversalCache, key string) (T, bool) {
|
||||
var zero T
|
||||
value, exists := c.Get(key)
|
||||
if !exists {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
typed, ok := value.(T)
|
||||
if !ok {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
return typed, true
|
||||
}
|
||||
|
||||
// TokenCacheOperations provides token-specific operations
|
||||
func (c *UniversalCache) BlacklistToken(token string, ttl time.Duration) error {
|
||||
if c.config.Type != CacheTypeToken {
|
||||
@@ -784,14 +834,81 @@ func (c *UniversalCache) Strategy() CacheStrategy {
|
||||
|
||||
// serialize converts a value to bytes for backend storage
|
||||
func (c *UniversalCache) serialize(value interface{}) ([]byte, error) {
|
||||
// Use JSON for serialization - simple and universal
|
||||
return json.Marshal(value)
|
||||
// If value is already a byte slice (e.g., pre-marshaled JSON from metadata_cache),
|
||||
// store it directly with a marker to prevent double-encoding.
|
||||
// This fixes the issue where []byte was being JSON-marshaled, causing Base64 encoding.
|
||||
if bytes, ok := value.([]byte); ok {
|
||||
// Validate size to prevent integer overflow
|
||||
if len(bytes) > maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("cache entry size %d exceeds maximum allowed size %d", len(bytes), maxCacheEntrySize)
|
||||
}
|
||||
// Check for potential overflow when adding marker byte
|
||||
if len(bytes) == maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("cache entry size would overflow when adding marker byte")
|
||||
}
|
||||
|
||||
// Prepend marker byte 0x00 to indicate raw bytes (not JSON-encoded)
|
||||
result := make([]byte, len(bytes)+1)
|
||||
result[0] = 0x00
|
||||
copy(result[1:], bytes)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// For all other types (maps, strings, etc.), use JSON encoding
|
||||
// Prepend marker byte 0x01 to indicate JSON-encoded data
|
||||
jsonData, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate size to prevent integer overflow
|
||||
if len(jsonData) > maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("serialized cache entry size %d exceeds maximum allowed size %d", len(jsonData), maxCacheEntrySize)
|
||||
}
|
||||
// Check for potential overflow when adding marker byte
|
||||
if len(jsonData) == maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("serialized cache entry size would overflow when adding marker byte")
|
||||
}
|
||||
|
||||
result := make([]byte, len(jsonData)+1)
|
||||
result[0] = 0x01
|
||||
copy(result[1:], jsonData)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// deserialize converts bytes from backend storage to a value
|
||||
func (c *UniversalCache) deserialize(data []byte, value interface{}) error {
|
||||
// Use JSON for deserialization
|
||||
return json.Unmarshal(data, value)
|
||||
if len(data) == 0 {
|
||||
return fmt.Errorf("cannot deserialize empty data")
|
||||
}
|
||||
|
||||
// Check for type marker (added by serialize)
|
||||
if data[0] == 0x00 {
|
||||
// Raw bytes - strip marker and return as-is
|
||||
rawBytes := data[1:]
|
||||
if ptr, ok := value.(*interface{}); ok {
|
||||
*ptr = rawBytes
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cannot deserialize raw bytes into %T", value)
|
||||
}
|
||||
|
||||
if data[0] == 0x01 {
|
||||
// JSON-encoded - strip marker and unmarshal
|
||||
return json.Unmarshal(data[1:], value)
|
||||
}
|
||||
|
||||
// Legacy data without marker (for backward compatibility)
|
||||
// Try to unmarshal as JSON
|
||||
if err := json.Unmarshal(data, value); err != nil {
|
||||
// If unmarshal fails, treat as raw bytes
|
||||
if ptr, ok := value.(*interface{}); ok {
|
||||
*ptr = data
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prefixKey adds a cache type prefix to the key for backend storage
|
||||
|
||||
@@ -0,0 +1,517 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestUniversalCache_SerializeDeserialize tests the fix for issue #116
|
||||
// where metadata was stored as Base64-encoded JSON but read as plain JSON
|
||||
func TestUniversalCache_SerializeDeserialize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RawBytesPreserved", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Test data: pre-marshaled JSON bytes (like metadata_cache uses)
|
||||
testData := []byte(`{"issuer":"https://example.com","jwks_uri":"https://example.com/jwks"}`)
|
||||
|
||||
// Serialize
|
||||
serialized, err := cache.serialize(testData)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, serialized)
|
||||
|
||||
// Should have marker byte
|
||||
assert.Equal(t, byte(0x00), serialized[0], "Should have raw bytes marker")
|
||||
assert.Equal(t, testData, serialized[1:], "Data should be preserved after marker")
|
||||
|
||||
// Deserialize
|
||||
var result interface{}
|
||||
err = cache.deserialize(serialized, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get back []byte
|
||||
resultBytes, ok := result.([]byte)
|
||||
require.True(t, ok, "Result should be []byte")
|
||||
assert.Equal(t, testData, resultBytes, "Deserialized data should match original")
|
||||
})
|
||||
|
||||
t.Run("JSONEncodedTypes", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
}{
|
||||
{
|
||||
name: "Map",
|
||||
value: map[string]interface{}{"key": "value", "number": 42.0},
|
||||
},
|
||||
{
|
||||
name: "String",
|
||||
value: "test-string",
|
||||
},
|
||||
{
|
||||
name: "Number",
|
||||
value: 123.456,
|
||||
},
|
||||
{
|
||||
name: "Array",
|
||||
value: []interface{}{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Serialize
|
||||
serialized, err := cache.serialize(tc.value)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, serialized)
|
||||
|
||||
// Should have JSON marker byte
|
||||
assert.Equal(t, byte(0x01), serialized[0], "Should have JSON marker")
|
||||
|
||||
// Verify the JSON portion is valid
|
||||
var checkJSON interface{}
|
||||
err = json.Unmarshal(serialized[1:], &checkJSON)
|
||||
require.NoError(t, err, "Should be valid JSON after marker")
|
||||
|
||||
// Deserialize
|
||||
var result interface{}
|
||||
err = cache.deserialize(serialized, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare results (using JSON round-trip for consistent comparison)
|
||||
expectedJSON, _ := json.Marshal(tc.value)
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
assert.JSONEq(t, string(expectedJSON), string(resultJSON), "Deserialized data should match original")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LegacyDataCompatibility", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Simulate legacy data (JSON without marker byte)
|
||||
legacyData := []byte(`{"legacy":"data"}`)
|
||||
|
||||
var result interface{}
|
||||
err := cache.deserialize(legacyData, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should successfully unmarshal as JSON
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
require.True(t, ok, "Should unmarshal legacy JSON data")
|
||||
assert.Equal(t, "data", resultMap["legacy"])
|
||||
})
|
||||
|
||||
t.Run("EmptyDataHandling", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
var result interface{}
|
||||
err := cache.deserialize([]byte{}, &result)
|
||||
assert.Error(t, err, "Should error on empty data")
|
||||
assert.Contains(t, err.Error(), "empty data")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_LargeBytes", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a byte slice that exceeds maxCacheEntrySize (64 MiB)
|
||||
oversizedBytes := make([]byte, 65*1024*1024) // 65 MiB
|
||||
|
||||
// Attempt to serialize - should fail with overflow error
|
||||
_, err := cache.serialize(oversizedBytes)
|
||||
require.Error(t, err, "Should error on oversized byte slice")
|
||||
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_ExactMaxSize", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a byte slice exactly at maxCacheEntrySize
|
||||
// This should fail because adding marker byte would overflow
|
||||
exactMaxBytes := make([]byte, 64*1024*1024) // Exactly 64 MiB
|
||||
|
||||
_, err := cache.serialize(exactMaxBytes)
|
||||
require.Error(t, err, "Should error when adding marker would overflow")
|
||||
assert.Contains(t, err.Error(), "would overflow when adding marker byte")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_SafeSize", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a byte slice well within limits
|
||||
safeBytes := make([]byte, 1024*1024) // 1 MiB - safe size
|
||||
|
||||
serialized, err := cache.serialize(safeBytes)
|
||||
require.NoError(t, err, "Should succeed with safe size")
|
||||
assert.NotNil(t, serialized)
|
||||
assert.Equal(t, len(safeBytes)+1, len(serialized), "Should add marker byte")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_JSONData", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a very large map that will exceed limits when JSON-encoded
|
||||
largeMap := make(map[string]string)
|
||||
// Each entry is roughly 50 bytes, so we need ~1.3M entries to exceed 64 MiB
|
||||
for i := 0; i < 1400000; i++ {
|
||||
key := fmt.Sprintf("key_%d", i)
|
||||
largeMap[key] = "value_with_some_content_to_make_it_larger"
|
||||
}
|
||||
|
||||
_, err := cache.serialize(largeMap)
|
||||
require.Error(t, err, "Should error when JSON serialization exceeds size limit")
|
||||
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
|
||||
})
|
||||
}
|
||||
|
||||
// TestUniversalCache_RedisIntegration_Issue116 tests the complete fix for issue #116
|
||||
// with actual Redis backend to ensure metadata cache works correctly
|
||||
func TestUniversalCache_RedisIntegration_Issue116(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Start miniredis server
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
// Create Redis backend
|
||||
redisConfig := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisConfig.RedisPrefix = "test:"
|
||||
backend, err := backends.NewRedisBackend(redisConfig)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
t.Run("MetadataCache_StoreAndRetrieve", func(t *testing.T) {
|
||||
// Create cache with Redis backend
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Simulate metadata_cache.Set behavior:
|
||||
// 1. Marshal metadata to JSON
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
TokenURL: "https://example.com/token",
|
||||
AuthURL: "https://example.com/authorize",
|
||||
}
|
||||
jsonData, err := json.Marshal(metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Store the JSON bytes
|
||||
key := "v2:https://example.com"
|
||||
err = cache.Set(key, jsonData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Retrieve the data
|
||||
retrieved, exists := cache.Get(key)
|
||||
require.True(t, exists, "Data should exist in cache")
|
||||
|
||||
// 4. Should get back []byte (not a string or map)
|
||||
retrievedBytes, ok := retrieved.([]byte)
|
||||
require.True(t, ok, "Retrieved value should be []byte, got %T", retrieved)
|
||||
|
||||
// 5. Should be able to unmarshal as JSON
|
||||
var retrievedMetadata ProviderMetadata
|
||||
err = json.Unmarshal(retrievedBytes, &retrievedMetadata)
|
||||
require.NoError(t, err, "Should be able to unmarshal retrieved bytes as JSON")
|
||||
|
||||
// 6. Verify data integrity
|
||||
assert.Equal(t, metadata.Issuer, retrievedMetadata.Issuer)
|
||||
assert.Equal(t, metadata.JWKSURL, retrievedMetadata.JWKSURL)
|
||||
assert.Equal(t, metadata.TokenURL, retrievedMetadata.TokenURL)
|
||||
})
|
||||
|
||||
t.Run("MetadataCache_NoBase64Encoding", func(t *testing.T) {
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Store JSON bytes
|
||||
jsonData := []byte(`{"issuer":"https://test.com"}`)
|
||||
key := "v2:https://test.com"
|
||||
err = cache.Set(key, jsonData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve
|
||||
retrieved, exists := cache.Get(key)
|
||||
require.True(t, exists)
|
||||
|
||||
retrievedBytes, ok := retrieved.([]byte)
|
||||
require.True(t, ok)
|
||||
|
||||
// The retrieved data should NOT start with "eyJ" (Base64 encoding of "{")
|
||||
// This was the bug in issue #116
|
||||
assert.NotEqual(t, []byte("eyJ"), retrievedBytes[:3], "Data should not be Base64 encoded")
|
||||
|
||||
// Should be valid JSON
|
||||
var checkJSON map[string]interface{}
|
||||
err = json.Unmarshal(retrievedBytes, &checkJSON)
|
||||
require.NoError(t, err, "Data should be valid JSON")
|
||||
assert.Equal(t, "https://test.com", checkJSON["issuer"])
|
||||
})
|
||||
|
||||
t.Run("TokenCache_MapValues", func(t *testing.T) {
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Store a map (like TokenCache does)
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": 1234567890.0,
|
||||
"scope": "read write",
|
||||
}
|
||||
key := "token:abc123"
|
||||
err = cache.Set(key, claims, 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve
|
||||
retrieved, exists := cache.Get(key)
|
||||
require.True(t, exists)
|
||||
|
||||
// Should get back a map
|
||||
retrievedMap, ok := retrieved.(map[string]interface{})
|
||||
require.True(t, ok, "Retrieved value should be map[string]interface{}")
|
||||
assert.Equal(t, "user123", retrievedMap["sub"])
|
||||
assert.Equal(t, 1234567890.0, retrievedMap["exp"])
|
||||
})
|
||||
|
||||
t.Run("MixedTypes_SameCache", func(t *testing.T) {
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Store different types
|
||||
jsonBytes := []byte(`{"type":"json-bytes"}`)
|
||||
err = cache.Set("key1", jsonBytes, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mapData := map[string]interface{}{"type": "map"}
|
||||
err = cache.Set("key2", mapData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
stringData := "plain-string"
|
||||
err = cache.Set("key3", stringData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify each type
|
||||
val1, exists := cache.Get("key1")
|
||||
require.True(t, exists)
|
||||
bytes1, ok := val1.([]byte)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jsonBytes, bytes1)
|
||||
|
||||
val2, exists := cache.Get("key2")
|
||||
require.True(t, exists)
|
||||
map2, ok := val2.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "map", map2["type"])
|
||||
|
||||
val3, exists := cache.Get("key3")
|
||||
require.True(t, exists)
|
||||
str3, ok := val3.(string)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, stringData, str3)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUniversalCache_BackwardCompatibility tests that old cached data is handled gracefully
|
||||
func TestUniversalCache_BackwardCompatibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisConfig := backends.DefaultRedisConfig(mr.Addr())
|
||||
backend, err := backends.NewRedisBackend(redisConfig)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("LegacyJSONData", func(t *testing.T) {
|
||||
// Manually insert legacy data (plain JSON without marker)
|
||||
legacyKey := "general:legacy-key"
|
||||
legacyData := []byte(`{"old":"format"}`)
|
||||
err = backend.Set(ctx, legacyKey, legacyData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to retrieve via UniversalCache
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
retrieved, exists := cache.Get("legacy-key")
|
||||
require.True(t, exists, "Should retrieve legacy data")
|
||||
|
||||
// Should deserialize as JSON map
|
||||
retrievedMap, ok := retrieved.(map[string]interface{})
|
||||
require.True(t, ok, "Should unmarshal legacy JSON")
|
||||
assert.Equal(t, "format", retrievedMap["old"])
|
||||
})
|
||||
|
||||
t.Run("LegacyCorruptData", func(t *testing.T) {
|
||||
// Insert corrupt/invalid data
|
||||
corruptKey := "general:corrupt-key"
|
||||
corruptData := []byte("not json and no marker")
|
||||
err = backend.Set(ctx, corruptKey, corruptData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
retrieved, exists := cache.Get("corrupt-key")
|
||||
require.True(t, exists)
|
||||
|
||||
// Should return as raw bytes (fallback)
|
||||
retrievedBytes, ok := retrieved.([]byte)
|
||||
require.True(t, ok, "Should return corrupt data as raw bytes")
|
||||
assert.Equal(t, corruptData, retrievedBytes)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetadataCache_Issue116_Regression is the main regression test for issue #116
|
||||
// This specifically tests the scenario described in the GitHub issue
|
||||
func TestMetadataCache_Issue116_Regression(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
// Create Redis backend
|
||||
redisConfig := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisConfig.RedisPrefix = "traefik:"
|
||||
backend, err := backends.NewRedisBackend(redisConfig)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
// Create a simple logger
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
// Create metadata cache instance
|
||||
metadataCache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true,
|
||||
}, backend)
|
||||
defer metadataCache.Close()
|
||||
|
||||
// Use the actual MetadataCache wrapper
|
||||
wg := &sync.WaitGroup{}
|
||||
mc := &MetadataCache{
|
||||
cache: metadataCache,
|
||||
logger: logger,
|
||||
wg: wg,
|
||||
}
|
||||
|
||||
// Test: Store and retrieve metadata (the scenario from issue #116)
|
||||
providerURL := "https://example.com"
|
||||
metadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/authorize",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
RevokeURL: "https://example.com/revoke",
|
||||
EndSessionURL: "https://example.com/logout",
|
||||
RegistrationURL: "https://example.com/register",
|
||||
ScopesSupported: []string{"openid", "profile", "email"},
|
||||
}
|
||||
|
||||
// Store metadata
|
||||
err = mc.Set(providerURL, metadata, 1*time.Hour)
|
||||
require.NoError(t, err, "Should store metadata without error")
|
||||
|
||||
// Retrieve metadata
|
||||
retrieved, exists := mc.Get(providerURL)
|
||||
require.True(t, exists, "Should retrieve stored metadata")
|
||||
require.NotNil(t, retrieved, "Retrieved metadata should not be nil")
|
||||
|
||||
// Verify no corruption - this was failing in issue #116 with "invalid character 'e'" error
|
||||
assert.Equal(t, metadata.Issuer, retrieved.Issuer)
|
||||
assert.Equal(t, metadata.AuthURL, retrieved.AuthURL)
|
||||
assert.Equal(t, metadata.TokenURL, retrieved.TokenURL)
|
||||
assert.Equal(t, metadata.JWKSURL, retrieved.JWKSURL)
|
||||
|
||||
// Verify the data is not Base64-encoded in Redis
|
||||
// This checks the root cause mentioned in the issue
|
||||
ctx := context.Background()
|
||||
rawData, _, exists, err := backend.Get(ctx, "metadata:v2:"+providerURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
|
||||
// Strip the marker byte
|
||||
require.Greater(t, len(rawData), 1, "Data should have marker byte")
|
||||
dataWithoutMarker := rawData[1:]
|
||||
|
||||
// Should not start with "eyJ" (Base64 encoding of "{")
|
||||
if len(dataWithoutMarker) >= 3 {
|
||||
assert.NotEqual(t, "eyJ", string(dataWithoutMarker[:3]), "Data should not be Base64-encoded")
|
||||
}
|
||||
|
||||
// Should be valid JSON
|
||||
var checkMetadata ProviderMetadata
|
||||
err = json.Unmarshal(dataWithoutMarker, &checkMetadata)
|
||||
require.NoError(t, err, "Stored data should be valid JSON, not Base64")
|
||||
assert.Equal(t, metadata.Issuer, checkMetadata.Issuer)
|
||||
}
|
||||
+108
-51
@@ -13,20 +13,23 @@ import (
|
||||
// It runs a single consolidated cleanup goroutine for all caches, reducing
|
||||
// goroutine count and CPU overhead compared to per-cache cleanup routines.
|
||||
type UniversalCacheManager struct {
|
||||
sharedBackend backends.CacheBackend
|
||||
ctx context.Context
|
||||
tokenTypeCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache
|
||||
tokenCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
cleanupStarted bool
|
||||
sharedBackend backends.CacheBackend
|
||||
ctx context.Context
|
||||
tokenTypeCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache
|
||||
tokenCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
|
||||
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
|
||||
refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator)
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
cleanupStarted bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -169,6 +172,28 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// Initialize session invalidation cache for backchannel/front-channel logout
|
||||
// This cache stores invalidated session IDs and subjects to revoke sessions
|
||||
manager.sessionInvalidationCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 5000, // Support many concurrent invalidations
|
||||
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// Refresh-result cache: short-lived store keyed by sha256(refreshToken).
|
||||
// In Redis-backed mode this gives cross-replica dedup of refresh grants;
|
||||
// in memory-only mode it's effectively redundant with RefreshCoordinator
|
||||
// but safe and cheap to keep.
|
||||
manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
}
|
||||
|
||||
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
|
||||
@@ -185,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
RedisPrefix: redisConfig.KeyPrefix,
|
||||
PoolSize: redisConfig.PoolSize,
|
||||
EnableMetrics: true,
|
||||
EnableTLS: redisConfig.EnableTLS,
|
||||
TLSSkipVerify: redisConfig.TLSSkipVerify,
|
||||
}
|
||||
|
||||
// Use concrete type to avoid Yaegi reflection issues with interface assignment
|
||||
@@ -349,6 +376,47 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// DCR credentials cache - CRITICAL for distributed DCR across multiple nodes
|
||||
// Uses Redis backend to share client credentials across all Traefik replicas
|
||||
manager.dcrCredentialsCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100, // Few providers expected
|
||||
DefaultTTL: 30 * 24 * time.Hour, // 30 days default (credentials are long-lived)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("dcr"),
|
||||
)
|
||||
|
||||
// Session invalidation cache - CRITICAL for distributed backchannel/front-channel logout
|
||||
// Uses Redis backend to share session invalidations across all Traefik replicas
|
||||
manager.sessionInvalidationCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 5000, // Support many concurrent invalidations
|
||||
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("session_invalidation"),
|
||||
)
|
||||
|
||||
// Refresh-result cache - shared via Redis so concurrent refreshes across
|
||||
// Traefik replicas can dedup their grants. The 5s TTL is long enough for
|
||||
// peers to observe a recent refresh and short enough that a stale entry
|
||||
// can't be replayed against a now-rotated refresh token.
|
||||
manager.refreshResultCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("refresh_result"),
|
||||
)
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
@@ -396,6 +464,9 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
|
||||
m.sessionCache,
|
||||
m.introspectionCache,
|
||||
m.tokenTypeCache,
|
||||
m.dcrCredentialsCache,
|
||||
m.sessionInvalidationCache,
|
||||
m.refreshResultCache,
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -437,13 +508,6 @@ func (m *UniversalCacheManager) GetJWKCache() *UniversalCache {
|
||||
return m.jwkCache
|
||||
}
|
||||
|
||||
// GetSessionCache returns the session cache
|
||||
func (m *UniversalCacheManager) GetSessionCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.sessionCache
|
||||
}
|
||||
|
||||
// GetIntrospectionCache returns the token introspection cache
|
||||
func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
@@ -458,6 +522,28 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
|
||||
return m.tokenTypeCache
|
||||
}
|
||||
|
||||
// GetSessionInvalidationCache returns the session invalidation cache for backchannel/front-channel logout
|
||||
func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.sessionInvalidationCache
|
||||
}
|
||||
|
||||
// GetRefreshResultCache returns the short-lived refresh-result cache used to
|
||||
// coalesce refresh-token grants across Traefik replicas.
|
||||
func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.refreshResultCache
|
||||
}
|
||||
|
||||
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
|
||||
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.dcrCredentialsCache
|
||||
}
|
||||
|
||||
// Close shuts down all caches and the consolidated cleanup routine
|
||||
func (m *UniversalCacheManager) Close() error {
|
||||
// Stop the consolidated cleanup routine first
|
||||
@@ -473,7 +559,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
|
||||
// Close all caches first (they won't close the shared backend)
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
@@ -494,35 +580,6 @@ func (m *UniversalCacheManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeCacheManagerFromConfig initializes the cache manager with configuration
|
||||
// This should be called early in the application startup with the loaded configuration
|
||||
func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides only when
|
||||
// the config field is not explicitly set through Traefik
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
// Log cache backend selection
|
||||
if config.Redis != nil && config.Redis.Enabled {
|
||||
logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
|
||||
config.Redis.CacheMode, config.Redis.Address)
|
||||
} else {
|
||||
logger.Info("Initializing cache backend with memory-only mode")
|
||||
}
|
||||
|
||||
return GetUniversalCacheManagerWithConfig(logger, config.Redis)
|
||||
}
|
||||
|
||||
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
|
||||
// This should only be called in test code to ensure proper cleanup between tests
|
||||
func ResetUniversalCacheManagerForTesting() {
|
||||
|
||||
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
|
||||
t.safeLogDebug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
if t.refreshCoordinator != nil {
|
||||
t.refreshCoordinator.Shutdown()
|
||||
t.safeLogDebug("refreshCoordinator shut down")
|
||||
}
|
||||
|
||||
if t.goroutineWG != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user