mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 546ceb949c | |||
| f75b2f20e0 | |||
| cf6ed1da55 | |||
| f821b8829b | |||
| 5f9c574f95 | |||
| 7c6f09fb20 | |||
| 68e1c4319c | |||
| 17e3f8ef62 | |||
| 827926bc3a | |||
| abbfdb02a7 | |||
| 72e2b682bb | |||
| ae4ccaa89d | |||
| 984fd1c08f | |||
| 99bdd23986 | |||
| a548665edb | |||
| 8c5df82dcf | |||
| aa96e9dbee | |||
| 1e33bb0a4d | |||
| bfd702a447 | |||
| 68c150eba4 | |||
| 9cbca4c4fb | |||
| 684a990f59 | |||
| 1b6c8616fd | |||
| 4d28fa01ab | |||
| 2d1b04c637 | |||
| ccbb98b9dd | |||
| 1362cc0dac | |||
| 249dcad1b3 | |||
| de4b4d7258 | |||
| 9d52f1b018 |
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: lukaszraczylo
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: https://monzo.me/lukaszraczylo
|
||||
@@ -18,6 +18,6 @@ jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
go-version: "1.25.x"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
|
||||
@@ -19,5 +19,5 @@ jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
go-version: "1.25.x"
|
||||
secrets: inherit
|
||||
|
||||
+49
-32
@@ -14,21 +14,22 @@ linters:
|
||||
- gosec
|
||||
- misspell
|
||||
- noctx
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
disable:
|
||||
- exhaustive
|
||||
- funlen
|
||||
- gocognit
|
||||
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
|
||||
- goprintffuncname # Disabled: naming convention is project-specific
|
||||
- lll
|
||||
- mnd
|
||||
- testpackage
|
||||
- whitespace # Disabled: style preference about newlines
|
||||
- wsl
|
||||
settings:
|
||||
dupl:
|
||||
@@ -47,29 +48,13 @@ linters:
|
||||
- fmt.Fprintln
|
||||
goconst:
|
||||
min-len: 3
|
||||
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
|
||||
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
|
||||
ignore-tests: true
|
||||
gocritic:
|
||||
# Using default enabled checks in v2
|
||||
enabled-checks:
|
||||
- appendCombine
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- commentedOutCode
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- hexLiteral
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nestingReduce
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- stringXbytes
|
||||
- typeAssertChain
|
||||
- typeUnparen
|
||||
- unlabelStmt
|
||||
- yodaStyleExpr
|
||||
# Disable style-only checks that add noise
|
||||
disabled-checks:
|
||||
- ifElseChain # Style preference, switch not always clearer
|
||||
- elseif # Style preference
|
||||
gocyclo:
|
||||
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
|
||||
gosec:
|
||||
@@ -106,23 +91,23 @@ linters:
|
||||
- name: error-return
|
||||
- name: error-strings
|
||||
- name: error-naming
|
||||
- name: exported
|
||||
- name: if-return
|
||||
# - name: exported # Disabled: too noisy, not all exported functions need comments
|
||||
# - name: if-return # Disabled: style preference
|
||||
- name: increment-decrement
|
||||
- name: var-naming
|
||||
- name: var-declaration
|
||||
- name: package-comments
|
||||
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
|
||||
# - name: var-declaration # Disabled: explicit zero values can be clearer
|
||||
# - name: package-comments # Disabled: handled by other tools
|
||||
- name: range
|
||||
- name: receiver-naming
|
||||
- name: time-naming
|
||||
- name: unexported-return
|
||||
- name: indent-error-flow
|
||||
# - name: indent-error-flow # Disabled: style preference
|
||||
- name: errorf
|
||||
- name: empty-block
|
||||
# - name: empty-block # Disabled: sometimes empty blocks are intentional
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
|
||||
- name: unreachable-code
|
||||
- name: redefines-builtin-id
|
||||
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
|
||||
unparam:
|
||||
check-exported: false
|
||||
staticcheck:
|
||||
@@ -132,8 +117,15 @@ linters:
|
||||
- -QF1003 # Tagged switch - style preference, may affect Yaegi
|
||||
- -QF1007 # Merge conditional assignment - style preference
|
||||
- -QF1008 # Remove embedded field - may break Yaegi compatibility
|
||||
- -QF1011 # Omit type from declaration - style preference
|
||||
- -QF1012 # Use fmt.Fprintf - style preference
|
||||
- -SA9003 # Empty branch - sometimes intentional for future work
|
||||
- -ST1000 # Package comment format - not required for all packages
|
||||
- -ST1003 # Package name format - allowed for test packages
|
||||
- -ST1016 # Receiver name consistency - legacy code
|
||||
- -ST1020 # Comment format for methods - style preference
|
||||
- -ST1021 # Comment format for types - style preference
|
||||
- -ST1023 # Omit type from declaration - style preference
|
||||
exclusions:
|
||||
generated: lax
|
||||
rules:
|
||||
@@ -144,18 +136,43 @@ linters:
|
||||
- goconst
|
||||
- gocyclo
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
- revive
|
||||
- gocritic
|
||||
path: _test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
- gocyclo
|
||||
- govet
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
- revive
|
||||
- gocritic
|
||||
path: test.*\.go
|
||||
- linters:
|
||||
- gocritic
|
||||
- unused
|
||||
- errcheck
|
||||
- revive
|
||||
path: mocks.*\.go
|
||||
- linters:
|
||||
- errcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- govet
|
||||
- unparam
|
||||
path: internal/testutil/
|
||||
- linters:
|
||||
- govet
|
||||
- unparam
|
||||
- noctx
|
||||
- prealloc
|
||||
path: integration/
|
||||
- linters:
|
||||
- gosec
|
||||
text: 'G404:'
|
||||
|
||||
+81
-1913
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,61 @@
|
||||
# traefikoidc — Makefile
|
||||
# Run `make help` for available targets.
|
||||
|
||||
GO ?= go
|
||||
GOPATH := $(shell $(GO) env GOPATH)
|
||||
# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate
|
||||
# tests the real interpreter, not a newer one that may support more. Traefik
|
||||
# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is.
|
||||
YAEGI_VERSION ?= v0.16.1
|
||||
TEST_TIMEOUT ?= 480s
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
|
||||
.PHONY: help
|
||||
help: ## Show this help
|
||||
@grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: build
|
||||
build: ## Compile all packages (native toolchain)
|
||||
$(GO) build ./...
|
||||
|
||||
.PHONY: fmt
|
||||
fmt: ## Format sources with gofmt
|
||||
gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/')
|
||||
|
||||
.PHONY: vet
|
||||
vet: ## Run go vet
|
||||
$(GO) vet ./...
|
||||
|
||||
.PHONY: lint
|
||||
lint: ## Run golangci-lint if available
|
||||
@command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping"
|
||||
|
||||
.PHONY: staticcheck
|
||||
staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.)
|
||||
@command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; }
|
||||
@GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./...
|
||||
|
||||
.PHONY: test
|
||||
test: ## Run the test suite
|
||||
$(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT)
|
||||
|
||||
.PHONY: vendor
|
||||
vendor: ## Refresh and vendor dependencies
|
||||
$(GO) mod tidy && $(GO) mod vendor
|
||||
|
||||
# yaegi-validate interprets the plugin under the yaegi interpreter the same way
|
||||
# Traefik loads it. Native `go build`/`go test` use the standard compiler and do
|
||||
# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection
|
||||
# edge cases). This target does. Importing the package forces yaegi to interpret
|
||||
# every file in it plus its vendored deps; CreateConfig + New exercise the
|
||||
# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you
|
||||
# need exact parity.
|
||||
.PHONY: yaegi-validate
|
||||
yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter
|
||||
@command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); }
|
||||
@echo ">> interpreting plugin under yaegi (as Traefik does)"
|
||||
@DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go
|
||||
|
||||
.PHONY: check
|
||||
check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation
|
||||
+5
-3
@@ -484,7 +484,8 @@ func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
|
||||
if !authenticated || needsRefresh || expired {
|
||||
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
|
||||
authenticated, needsRefresh, expired)
|
||||
@@ -623,7 +624,8 @@ func TestAuth0Scenario2StrictMode(t *testing.T) {
|
||||
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
|
||||
|
||||
// In strict mode, this should FAIL (no fallback to ID token)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
|
||||
if authenticated {
|
||||
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
|
||||
}
|
||||
@@ -1491,7 +1493,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetUserIdentifier("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
|
||||
+69
-34
@@ -4,8 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
@@ -44,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -77,7 +76,12 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := uuid.NewString()
|
||||
csrfToken, err := newUUIDv4()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate CSRF token: %v", err)
|
||||
http.Error(rw, "Failed to generate CSRF token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
@@ -178,6 +182,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
if t.enablePKCE && codeVerifier == "" {
|
||||
t.logger.Error("PKCE is enabled but code verifier is missing from session during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
@@ -246,7 +255,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -259,7 +268,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
// Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored
|
||||
// from the original request target before using it as the post-login
|
||||
// redirect target. normalizeLogoutPath forces a host-relative path.
|
||||
redirectPath = normalizeLogoutPath(incomingPath)
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
@@ -286,7 +298,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -301,28 +313,6 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// isUserAuthenticated determines the authentication status and refresh requirements.
|
||||
// It delegates to provider-specific validation methods that handle different token types
|
||||
// and expiration behaviors.
|
||||
// Parameters:
|
||||
// - session: The session data containing authentication tokens.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated (bool): True if the user has valid tokens.
|
||||
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
|
||||
// - expired (bool): True if the session is unauthenticated, the token is missing,
|
||||
// or the token verification failed for reasons other than nearing/actual expiration.
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if t.isAzureProvider() {
|
||||
return t.validateAzureTokens(session)
|
||||
} else if t.isGoogleProvider() {
|
||||
return t.validateGoogleTokens(session)
|
||||
}
|
||||
// Auth0 and other providers can now use standard validation
|
||||
// which handles opaque tokens generically
|
||||
return t.validateStandardTokens(session)
|
||||
}
|
||||
|
||||
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
|
||||
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
|
||||
xhr := req.Header.Get("X-Requested-With")
|
||||
@@ -334,9 +324,54 @@ func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
|
||||
strings.Contains(accept, "application/json")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
// This is a heuristic check - actual implementation would depend on
|
||||
// the specific provider and token metadata
|
||||
return false // Placeholder implementation
|
||||
// isNonNavigationRequest reports whether the request is a browser
|
||||
// sub-resource (script, image, stylesheet, fetch, serviceWorker) rather than
|
||||
// a top-level HTML navigation. Non-navigation requests MUST NOT trigger an
|
||||
// OIDC redirect flow: several sub-resource loads happening in parallel would
|
||||
// each call defaultInitiateAuthentication, each overwriting the session's
|
||||
// CSRF/nonce, breaking the eventual callback (issue #129).
|
||||
//
|
||||
// Detection prefers Sec-Fetch-Mode, which all modern browsers send
|
||||
// (Chrome/Edge/Firefox/Safari). For older or non-browser clients we fall
|
||||
// back to Accept: if Accept is present and does not list text/html, treat
|
||||
// it as a sub-resource. An empty/missing Accept is assumed to be navigation
|
||||
// (safer to redirect than 401 on an ambiguous request).
|
||||
func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
|
||||
if mode := req.Header.Get("Sec-Fetch-Mode"); mode != "" {
|
||||
return mode != "navigate"
|
||||
}
|
||||
accept := req.Header.Get("Accept")
|
||||
if accept == "" || accept == "*/*" {
|
||||
return false
|
||||
}
|
||||
return !strings.Contains(accept, "text/html")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks whether the stored refresh token is likely
|
||||
// past its useful lifetime, using the cookie-side issued_at timestamp set by
|
||||
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
|
||||
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
|
||||
// MaxRefreshTokenAgeSeconds; 0 disables the check).
|
||||
//
|
||||
// The point of this check is to short-circuit the refresh path BEFORE the
|
||||
// thundering herd hits the IdP for a token the provider has almost certainly
|
||||
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
|
||||
// style polling clients from looping on invalid_grant after a long pause.
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
if t == nil || session == nil {
|
||||
return false
|
||||
}
|
||||
if t.maxRefreshTokenAge <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
issuedAt := session.GetRefreshTokenIssuedAt()
|
||||
if issuedAt.IsZero() {
|
||||
// No timestamp recorded (legacy session pre-dating the issued_at
|
||||
// field). Don't force a re-auth - attempt refresh once and let the
|
||||
// IdP be the source of truth.
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(issuedAt) > t.maxRefreshTokenAge
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Pre-populate session with old data
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-access-token-with-many-characters")
|
||||
session.SetRefreshToken("old-refresh-token-with-many-characters")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Verify old data is cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
|
||||
// Verify new data is set
|
||||
s.Equal(csrfToken, session.GetCSRF())
|
||||
@@ -305,6 +305,90 @@ func (s *AuthFlowBehaviourSuite) TestIsAjaxRequest() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsNonNavigationRequest verifies browser sub-resource detection used to
|
||||
// suppress OIDC redirects on parallel static-asset loads (issue #129).
|
||||
func (s *AuthFlowBehaviourSuite) TestIsNonNavigationRequest() {
|
||||
testCases := []struct {
|
||||
headers map[string]string
|
||||
name string
|
||||
expectNonNavigation bool
|
||||
}{
|
||||
{
|
||||
name: "Sec-Fetch-Mode navigate",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "navigate"},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode no-cors",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "no-cors"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode cors",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "cors"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode same-origin (fetch in page)",
|
||||
headers: map[string]string{"Sec-Fetch-Mode": "same-origin"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Accept text/html (fallback)",
|
||||
headers: map[string]string{"Accept": "text/html,application/xhtml+xml"},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Accept image/png (fallback)",
|
||||
headers: map[string]string{"Accept": "image/png,image/*;q=0.8"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Accept application/javascript (fallback)",
|
||||
headers: map[string]string{"Accept": "application/javascript"},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
{
|
||||
name: "Accept */* treated as navigation",
|
||||
headers: map[string]string{"Accept": "*/*"},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "No Accept header assumed navigation",
|
||||
headers: map[string]string{},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode beats Accept (navigate wins)",
|
||||
headers: map[string]string{
|
||||
"Sec-Fetch-Mode": "navigate",
|
||||
"Accept": "application/javascript",
|
||||
},
|
||||
expectNonNavigation: false,
|
||||
},
|
||||
{
|
||||
name: "Sec-Fetch-Mode beats Accept (no-cors wins)",
|
||||
headers: map[string]string{
|
||||
"Sec-Fetch-Mode": "no-cors",
|
||||
"Accept": "text/html",
|
||||
},
|
||||
expectNonNavigation: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/_static/asset.js", nil)
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
result := s.tOidc.isNonNavigationRequest(req)
|
||||
s.Equal(tc.expectNonNavigation, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleCallback_MissingState tests callback with missing state parameter
|
||||
func (s *AuthFlowBehaviourSuite) TestHandleCallback_MissingState() {
|
||||
sessionManager, err := NewSessionManager(
|
||||
@@ -627,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
session, err := sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
session.mainSession.Values["redirect_count"] = 3
|
||||
|
||||
@@ -636,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
|
||||
// Session should be cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
s.Empty(session.GetIDToken())
|
||||
|
||||
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
|
||||
|
||||
+4
-3
@@ -599,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
|
||||
return globalTaskMemoryMonitor
|
||||
}
|
||||
|
||||
// NewTaskMemoryMonitor creates a new memory monitor for task registry
|
||||
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
|
||||
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
|
||||
//
|
||||
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
|
||||
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
|
||||
return GetGlobalTaskMemoryMonitor(logger)
|
||||
}
|
||||
@@ -712,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
|
||||
|
||||
// Check for goroutine leaks (arbitrary threshold)
|
||||
if stats.Goroutines > 100 {
|
||||
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
|
||||
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
|
||||
}
|
||||
|
||||
// Check for heap growth without corresponding GC activity
|
||||
|
||||
+8
-4
@@ -262,7 +262,8 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test that CSRF is preserved during Azure validation failures
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokensRS(rs)
|
||||
|
||||
// Should not be authenticated due to validation failure
|
||||
if authenticated {
|
||||
@@ -453,7 +454,8 @@ func TestValidateGoogleTokens(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
|
||||
auth, refresh, expired := ts.tOidc.validateGoogleTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
auth, refresh, expired := ts.tOidc.validateGoogleTokensRS(rs)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
@@ -637,7 +639,8 @@ func TestIsUserAuthenticated(t *testing.T) {
|
||||
defer func() { ts.tOidc.issuerURL = originalIssuer }()
|
||||
|
||||
session := tt.setupSession()
|
||||
auth, refresh, expired := ts.tOidc.isUserAuthenticated(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
auth, refresh, expired := ts.tOidc.isUserAuthenticatedRS(rs)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
@@ -762,7 +765,8 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
|
||||
auth, refresh, expired := ts.tOidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
auth, refresh, expired := ts.tOidc.validateAzureTokensRS(rs)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
|
||||
@@ -29,8 +29,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Equal(t, MemoryPressureNone, pressure)
|
||||
|
||||
// Collect stats to populate lastStats
|
||||
monitor.GetCurrentStats()
|
||||
// Explicitly sample to populate lastStats; GetCurrentStats is now a
|
||||
// cached read and no longer forces a runtime.ReadMemStats.
|
||||
monitor.Refresh()
|
||||
|
||||
// Now should return a valid pressure level
|
||||
pressure = monitor.GetMemoryPressure()
|
||||
@@ -46,11 +47,13 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Start monitoring should not panic
|
||||
// Start monitoring should not panic. Interval is clamped to the
|
||||
// minimum (30s); we rely on Refresh() when we need a synchronous
|
||||
// sample instead of waiting for a tick.
|
||||
assert.NotPanics(t, func() {
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 100*time.Millisecond)
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
monitor.StartMonitoring(ctx, 0)
|
||||
monitor.Refresh()
|
||||
})
|
||||
|
||||
// Clean up
|
||||
@@ -117,6 +120,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Refresh forces a synchronous sample; GetCurrentStats is a cached
|
||||
// read, so we sample first to guarantee fresh data.
|
||||
monitor.Refresh()
|
||||
stats := monitor.GetCurrentStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
@@ -450,12 +456,12 @@ func TestMemoryMonitorIntegration(t *testing.T) {
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
defer monitor.StopMonitoring()
|
||||
|
||||
// Start monitoring
|
||||
// Start monitoring. The interval is clamped to the minimum (30s) so
|
||||
// the ticker won't fire during the test; drive the sample manually via
|
||||
// Refresh() instead.
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 50*time.Millisecond)
|
||||
|
||||
// Wait for at least one check
|
||||
time.Sleep(GetTestDuration(150 * time.Millisecond))
|
||||
monitor.StartMonitoring(ctx, 0)
|
||||
monitor.Refresh()
|
||||
|
||||
// Get pressure (should be a valid pressure level)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
@@ -488,6 +494,7 @@ func TestMemoryStatsCollection(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
monitor.Refresh()
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
assert.NotNil(t, stats)
|
||||
@@ -501,6 +508,7 @@ func TestMemoryStatsCollection(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
monitor.Refresh()
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
// Should calculate and include pressure level
|
||||
@@ -521,13 +529,14 @@ func TestMemoryStatsCollection(t *testing.T) {
|
||||
// Allocate some memory
|
||||
_ = make([]byte, 1024*1024) // 1MB
|
||||
|
||||
// Get stats before GC
|
||||
beforeStats := monitor.GetCurrentStats()
|
||||
// Get stats before GC (explicit Refresh so we have a fresh pre-GC
|
||||
// snapshot to compare against, not the constructor baseline).
|
||||
beforeStats := monitor.Refresh()
|
||||
|
||||
// Trigger GC
|
||||
// Trigger GC (internally Refresh()es before and after)
|
||||
monitor.TriggerGC()
|
||||
|
||||
// Get stats after GC
|
||||
// Get stats after GC from cache (TriggerGC already refreshed it)
|
||||
afterStats := monitor.GetCurrentStats()
|
||||
|
||||
// After GC should have different stats
|
||||
|
||||
+683
@@ -0,0 +1,683 @@
|
||||
// Package traefikoidc — bearer-token (M2M) authentication path.
|
||||
//
|
||||
// Disabled by default. When enabled via Config.EnableBearerAuth, requests
|
||||
// presenting "Authorization: Bearer <jwt>" are validated against the
|
||||
// configured OIDC provider (signature, issuer, audience, exp, replay-Get)
|
||||
// and the request is forwarded downstream without creating a cookie session.
|
||||
//
|
||||
// Design rules (kept here in code as the single source of truth):
|
||||
// - Access tokens only. ID tokens are rejected via detectTokenType.
|
||||
// - Audience is mandatory (enforced at startup in main.go).
|
||||
// - alg + kid pinned BEFORE JWKS fetch to deny amplification probes.
|
||||
// - iat upper-age cap bounds clock-skew / forever-token abuse.
|
||||
// - Multi-audience tokens require matching azp.
|
||||
// - Per-IP 401 throttle returns 429 + Retry-After after a threshold.
|
||||
// - JTI Set is suppressed (skipReplayMarking) but JTI Get stays — revoked
|
||||
// tokens (RevokeToken adds to blacklist) are still rejected.
|
||||
// - Identifier is read from BearerIdentifierClaim (default "sub"), never
|
||||
// from UserIdentifierClaim, to avoid the unverified-email spoofing path.
|
||||
// - Identifier is sanitized: length cap, control chars, bidi-override,
|
||||
// delimiter chars (, ; =) rejected.
|
||||
// - On excluded URLs the Authorization header is stripped before forwarding.
|
||||
//
|
||||
// See docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md and
|
||||
// docs/BEARER_AUTH.md for the full threat model.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
|
||||
// bearerAlgAllowlist is the set of JWS algorithms accepted on the bearer
|
||||
// path. Asymmetric-only — HS* would allow public-key-as-HMAC-secret attacks
|
||||
// if any operator ever rotates a key into the symmetric branch by mistake;
|
||||
// "none" is obvious. Matches the allowlist enforced inside jwt.Verify but is
|
||||
// checked here BEFORE the JWKS fetch so attacker noise can't amplify.
|
||||
var bearerAlgAllowlist = map[string]struct{}{
|
||||
"RS256": {}, "RS384": {}, "RS512": {},
|
||||
"PS256": {}, "PS384": {}, "PS512": {},
|
||||
"ES256": {}, "ES384": {}, "ES512": {},
|
||||
}
|
||||
|
||||
// bearerKidMaxLen caps the JOSE kid header length to keep memory and cache-key
|
||||
// usage bounded against attacker-controlled values.
|
||||
const bearerKidMaxLen = 256
|
||||
|
||||
// validKidChar is the allowlist for kid header characters. Letters, digits,
|
||||
// dot, underscore, hyphen, equals. Intentionally narrow; real-world kid
|
||||
// values are short URL-safe-base64-ish identifiers.
|
||||
func validKidChar(r rune) bool {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
return true
|
||||
}
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return true
|
||||
}
|
||||
if r >= '0' && r <= '9' {
|
||||
return true
|
||||
}
|
||||
switch r {
|
||||
case '.', '_', '-', '=':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// bearerError categorizes failure modes for the response builder. Categories
|
||||
// map 1:1 to the table in docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md
|
||||
// §9 so behavior is auditable from spec to code.
|
||||
type bearerErrorKind int
|
||||
|
||||
const (
|
||||
bearerErrInvalidRequest bearerErrorKind = iota
|
||||
bearerErrInvalidToken
|
||||
bearerErrTokenInactive
|
||||
bearerErrInvalidIdentifier
|
||||
bearerErrForbidden
|
||||
bearerErrThrottled
|
||||
bearerErrIntrospectionUnavailable
|
||||
)
|
||||
|
||||
type bearerError struct {
|
||||
kind bearerErrorKind
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *bearerError) Error() string { return e.reason }
|
||||
|
||||
func newBearerError(kind bearerErrorKind, reason string) *bearerError {
|
||||
return &bearerError{kind: kind, reason: reason}
|
||||
}
|
||||
|
||||
// joseHeader is the minimal subset of the JWS protected header we inspect
|
||||
// BEFORE running the full verification pipeline. Lifted out so the alg+kid
|
||||
// pin can run without paying for parseJWT's full claim decode.
|
||||
type joseHeader struct {
|
||||
Alg string `json:"alg"`
|
||||
Kid string `json:"kid"`
|
||||
Typ string `json:"typ"`
|
||||
}
|
||||
|
||||
// parseBearerJOSEHeader decodes the first JWT segment for early alg/kid pinning.
|
||||
// Does not touch the payload or signature — those are the verifier's job.
|
||||
// Returns nil on success; *bearerError on rejection so the handler can map
|
||||
// directly to a status code. The decoded header itself is not surfaced because
|
||||
// callers don't need it (verifyTokenWithOpts re-parses internally).
|
||||
func parseBearerJOSEHeader(token string) *bearerError {
|
||||
dot := strings.IndexByte(token, '.')
|
||||
if dot <= 0 {
|
||||
return newBearerError(bearerErrInvalidToken, "malformed JWT: no header segment")
|
||||
}
|
||||
raw, err := base64.RawURLEncoding.DecodeString(token[:dot])
|
||||
if err != nil {
|
||||
// Some IdPs pad with '='; tolerate by retrying with StdEncoding.
|
||||
raw, err = base64.URLEncoding.DecodeString(token[:dot])
|
||||
if err != nil {
|
||||
return newBearerError(bearerErrInvalidToken, "malformed JWT: header not base64url")
|
||||
}
|
||||
}
|
||||
var hdr joseHeader
|
||||
if err := json.Unmarshal(raw, &hdr); err != nil {
|
||||
return newBearerError(bearerErrInvalidToken, "malformed JWT: header not JSON")
|
||||
}
|
||||
if _, ok := bearerAlgAllowlist[hdr.Alg]; !ok {
|
||||
return newBearerError(bearerErrInvalidToken, fmt.Sprintf("disallowed alg %q on bearer path", hdr.Alg))
|
||||
}
|
||||
if hdr.Kid == "" {
|
||||
return newBearerError(bearerErrInvalidToken, "missing kid header")
|
||||
}
|
||||
if len(hdr.Kid) > bearerKidMaxLen {
|
||||
return newBearerError(bearerErrInvalidToken, "kid header exceeds max length")
|
||||
}
|
||||
for _, r := range hdr.Kid {
|
||||
if !validKidChar(r) {
|
||||
return newBearerError(bearerErrInvalidToken, "kid header contains disallowed characters")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerClaimRuneReason reports why a rune is unsafe to inject into a request
|
||||
// header value, or "" if the rune is acceptable. Shared core of the bearer-path
|
||||
// identifier sanitizer and the cookie-path header claim sanitizer: rejects
|
||||
// control chars (CRLF/header injection), Unicode bidi-override runes (RTL
|
||||
// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group
|
||||
// name would inject extra entries into a comma-joined header).
|
||||
func headerClaimRuneReason(r rune) string {
|
||||
if reason := headerInjectionRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
// The , ; = delimiters are only unsafe for values placed into delimited or
|
||||
// list contexts (a comma-joined header, or an identifier downstreams may
|
||||
// split). They are valid in arbitrary single header values, so this stricter
|
||||
// check is used for the cookie-path identifier and the group/role list, NOT
|
||||
// for free-form templated header output (see headerValueReason).
|
||||
if r == ',' || r == ';' || r == '=' {
|
||||
return "delimiter character"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header
|
||||
// value, or "" if acceptable. Rejects control characters (CR/LF header
|
||||
// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs).
|
||||
// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in
|
||||
// free-form header values (e.g. an opaque "Authorization: Bearer <token>").
|
||||
func headerInjectionRuneReason(r rune) string {
|
||||
if unicode.IsControl(r) {
|
||||
return "control character"
|
||||
}
|
||||
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
|
||||
return "bidi-override character"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerValueReason reports why value is unsafe to forward as a free-form HTTP
|
||||
// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0
|
||||
// disables the check) and values containing control or bidi-override runes, but
|
||||
// permits , ; = (valid in header values). Empty is allowed. The reason string
|
||||
// never includes the value, so it is safe to log.
|
||||
func headerValueReason(value string, maxLen int) string {
|
||||
if maxLen > 0 && len(value) > maxLen {
|
||||
return "exceeds max length"
|
||||
}
|
||||
for _, r := range value {
|
||||
if reason := headerInjectionRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerClaimValueReason reports why value is unsafe to inject into a
|
||||
// downstream request header, or "" if it is acceptable. It rejects empty
|
||||
// values, values exceeding maxLen (maxLen<=0 disables the length check), and
|
||||
// values containing any rune rejected by headerClaimRuneReason. The reason
|
||||
// string is safe to log (it never includes the value itself).
|
||||
func headerClaimValueReason(value string, maxLen int) string {
|
||||
if value == "" {
|
||||
return "empty value"
|
||||
}
|
||||
if maxLen > 0 && len(value) > maxLen {
|
||||
return "exceeds max length"
|
||||
}
|
||||
for _, r := range value {
|
||||
if reason := headerClaimRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sanitizeHeaderClaimValue validates a claim-derived value before it is
|
||||
// injected into a downstream request header. It trims surrounding whitespace
|
||||
// and fails closed (ok=false) on empty values, values exceeding maxLen
|
||||
// (maxLen<=0 disables the length check), or values containing any rune rejected
|
||||
// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the
|
||||
// bearer path — does not otherwise sanitize the principal identifier or the
|
||||
// group/role strings joined into X-User-Groups / X-User-Roles.
|
||||
func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if headerClaimValueReason(value, maxLen) != "" {
|
||||
return "", false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
// sanitizeBearerIdentifier validates and trims a principal identifier before
|
||||
// it is injected into request headers. Layered defense: net/http will reject
|
||||
// CRLF on the wire too, but rejecting early gives clearer error logs and
|
||||
// prevents bidi-override / delimiter chars that pass net/http's narrower
|
||||
// checks but confuse downstream parsers and admin UIs.
|
||||
func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
|
||||
identifier := strings.TrimSpace(raw)
|
||||
if identifier == "" {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier claim empty")
|
||||
}
|
||||
if maxLen > 0 && len(identifier) > maxLen {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
|
||||
}
|
||||
for _, r := range identifier {
|
||||
if reason := headerClaimRuneReason(r); reason != "" {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason)
|
||||
}
|
||||
}
|
||||
return identifier, nil
|
||||
}
|
||||
|
||||
// resolveBearerIdentifier picks the principal identifier from claims using
|
||||
// the configured BearerIdentifierClaim (default "sub"). Decoupled from
|
||||
// userIdentifierClaim (cookie path) to avoid the unverified-email spoofing
|
||||
// vector documented in the spec §13.
|
||||
func resolveBearerIdentifier(claims map[string]interface{}, claimName string) (string, *bearerError) {
|
||||
if claimName == "" {
|
||||
claimName = "sub"
|
||||
}
|
||||
raw, ok := claims[claimName]
|
||||
if !ok {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, fmt.Sprintf("missing claim %q", claimName))
|
||||
}
|
||||
str, ok := raw.(string)
|
||||
if !ok {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, fmt.Sprintf("claim %q not a string", claimName))
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// enforceMultiAudienceAzp implements the spec hardening: when aud is a
|
||||
// multi-element array, require an azp claim equal to clientID. Single-string
|
||||
// aud is unaffected (existing verifyAudience handles it).
|
||||
func enforceMultiAudienceAzp(claims map[string]interface{}, clientID string) *bearerError {
|
||||
audRaw, ok := claims["aud"]
|
||||
if !ok {
|
||||
return nil // verifyToken already rejects missing aud
|
||||
}
|
||||
arr, ok := audRaw.([]interface{})
|
||||
if !ok {
|
||||
return nil // single-string aud
|
||||
}
|
||||
if len(arr) <= 1 {
|
||||
return nil
|
||||
}
|
||||
azpRaw, ok := claims["azp"]
|
||||
if !ok {
|
||||
return newBearerError(bearerErrInvalidToken, "multi-audience token missing azp")
|
||||
}
|
||||
azp, ok := azpRaw.(string)
|
||||
if !ok || azp == "" {
|
||||
return newBearerError(bearerErrInvalidToken, "multi-audience token has empty/non-string azp")
|
||||
}
|
||||
if azp != clientID {
|
||||
return newBearerError(bearerErrInvalidToken, "multi-audience token azp does not match clientID")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceIatAge implements the spec MaxTokenAgeSeconds bound on iat. Bounds
|
||||
// clock-manipulation / forever-token abuse without rejecting tokens with a
|
||||
// normal iat just because the issuer's clock skews a few seconds.
|
||||
func enforceIatAge(claims map[string]interface{}, maxAge time.Duration) *bearerError {
|
||||
if maxAge <= 0 {
|
||||
return nil
|
||||
}
|
||||
iatRaw, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
// jwt.Verify already requires iat; this branch shouldn't be reached.
|
||||
return newBearerError(bearerErrInvalidToken, "missing iat claim")
|
||||
}
|
||||
iat := time.Unix(int64(iatRaw), 0)
|
||||
if time.Since(iat) > maxAge {
|
||||
return newBearerError(bearerErrInvalidToken, "token iat outside age bound")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hashIdentifierForLog returns a short SHA-256 prefix safe for info-level
|
||||
// logs. Full identifier is only emitted at debug. Satisfies the audit
|
||||
// requirement (trace which principal was rejected) without leaking PII.
|
||||
func hashIdentifierForLog(identifier string) string {
|
||||
if identifier == "" {
|
||||
return "(none)"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(identifier))
|
||||
return hex.EncodeToString(sum[:4]) // 8 hex chars
|
||||
}
|
||||
|
||||
// --- Per-IP failure throttle ---
|
||||
|
||||
// bearerFailureTracker records consecutive bearer-auth 401s per source IP and
|
||||
// parks repeat offenders in a 429 penalty box. Limits offline-guessing-style
|
||||
// attacks and protects the shared rate-limiter / JWKS endpoint from being
|
||||
// burned by a single source.
|
||||
type bearerFailureTracker struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*bearerFailureEntry
|
||||
// Configuration snapshot. Captured at construction so a hot reconfigure
|
||||
// doesn't race with the per-request paths.
|
||||
threshold int
|
||||
window time.Duration
|
||||
penalty time.Duration
|
||||
}
|
||||
|
||||
type bearerFailureEntry struct {
|
||||
firstFailureAt time.Time
|
||||
penaltyUntil time.Time
|
||||
count int
|
||||
}
|
||||
|
||||
func newBearerFailureTracker(threshold int, window, penalty time.Duration) *bearerFailureTracker {
|
||||
if threshold <= 0 {
|
||||
threshold = 20
|
||||
}
|
||||
if window <= 0 {
|
||||
window = 60 * time.Second
|
||||
}
|
||||
if penalty <= 0 {
|
||||
penalty = 60 * time.Second
|
||||
}
|
||||
return &bearerFailureTracker{
|
||||
entries: make(map[string]*bearerFailureEntry),
|
||||
threshold: threshold,
|
||||
window: window,
|
||||
penalty: penalty,
|
||||
}
|
||||
}
|
||||
|
||||
// blocked reports whether the source IP is currently in the penalty box.
|
||||
// Returns (true, retryAfter) when blocked; (false, 0) when allowed.
|
||||
func (b *bearerFailureTracker) blocked(ip string) (bool, time.Duration) {
|
||||
if b == nil || ip == "" {
|
||||
return false, 0
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
e, ok := b.entries[ip]
|
||||
if !ok {
|
||||
return false, 0
|
||||
}
|
||||
now := time.Now()
|
||||
if !e.penaltyUntil.IsZero() && now.Before(e.penaltyUntil) {
|
||||
return true, time.Until(e.penaltyUntil)
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// recordFailure increments the failure counter for the given IP and trips
|
||||
// the penalty box once threshold-within-window is exceeded.
|
||||
func (b *bearerFailureTracker) recordFailure(ip string) {
|
||||
if b == nil || ip == "" {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
now := time.Now()
|
||||
e, ok := b.entries[ip]
|
||||
if !ok || now.Sub(e.firstFailureAt) > b.window {
|
||||
e = &bearerFailureEntry{firstFailureAt: now}
|
||||
b.entries[ip] = e
|
||||
}
|
||||
e.count++
|
||||
if e.count >= b.threshold {
|
||||
e.penaltyUntil = now.Add(b.penalty)
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess clears the failure counter for the given IP after a
|
||||
// successful bearer auth.
|
||||
func (b *bearerFailureTracker) recordSuccess(ip string) {
|
||||
if b == nil || ip == "" {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
e, ok := b.entries[ip]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Preserve an active penalty so a single success cannot wipe an in-effect
|
||||
// lockout; only reset the counter when no penalty is active or it has expired.
|
||||
now := time.Now()
|
||||
if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) {
|
||||
e.count = 0
|
||||
e.firstFailureAt = now
|
||||
}
|
||||
}
|
||||
|
||||
// clientIPForBearer returns the source IP used to key the failure tracker.
|
||||
// Trusts only the request's transport-level RemoteAddr; X-Forwarded-For is
|
||||
// intentionally ignored to avoid attacker-controlled key spoofing. Behind a
|
||||
// trusted reverse proxy where every request shares one IP, the throttle is
|
||||
// still useful (caps attacker churn through that proxy) — operators wanting
|
||||
// per-real-client throttling must terminate at this middleware.
|
||||
func clientIPForBearer(req *http.Request) string {
|
||||
if req == nil {
|
||||
return ""
|
||||
}
|
||||
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
return req.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// --- Bearer auth entrypoint ---
|
||||
|
||||
// detectBearerToken returns (token, true) when the request carries a usable
|
||||
// Authorization: Bearer header. Case-insensitive on the scheme. Returns
|
||||
// ("", false) for any other shape.
|
||||
func detectBearerToken(req *http.Request) (string, bool) {
|
||||
if req == nil {
|
||||
return "", false
|
||||
}
|
||||
h := req.Header.Get("Authorization")
|
||||
if len(h) < len(bearerPrefix) {
|
||||
return "", false
|
||||
}
|
||||
if !strings.EqualFold(h[:len(bearerPrefix)], bearerPrefix) {
|
||||
return "", false
|
||||
}
|
||||
token := strings.TrimSpace(h[len(bearerPrefix):])
|
||||
if token == "" {
|
||||
return "", false
|
||||
}
|
||||
return token, true
|
||||
}
|
||||
|
||||
// hasSessionCookie reports whether the request carries any cookie matching
|
||||
// the session prefix. Used to implement the cookie-wins-by-default
|
||||
// precedence rule when both bearer and cookie are present.
|
||||
func (t *TraefikOidc) hasSessionCookie(req *http.Request) bool {
|
||||
if t.sessionManager == nil {
|
||||
return false
|
||||
}
|
||||
prefix := t.sessionManager.GetCookiePrefix()
|
||||
if prefix == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range req.Cookies() {
|
||||
if strings.HasPrefix(c.Name, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeBearerError writes the canonical 401/403/429/503 response per spec §9.
|
||||
// Body is always generic; reason is logged at debug only. The
|
||||
// WWW-Authenticate hint is gated by config (default on, RFC 6750 compliant).
|
||||
func (t *TraefikOidc) writeBearerError(rw http.ResponseWriter, req *http.Request, err *bearerError) {
|
||||
var (
|
||||
status int
|
||||
errCode string
|
||||
body string
|
||||
retryAfter time.Duration
|
||||
)
|
||||
switch err.kind {
|
||||
case bearerErrInvalidRequest:
|
||||
status = http.StatusUnauthorized
|
||||
errCode = "invalid_request"
|
||||
body = "Unauthorized"
|
||||
case bearerErrInvalidToken, bearerErrTokenInactive, bearerErrInvalidIdentifier:
|
||||
status = http.StatusUnauthorized
|
||||
errCode = "invalid_token"
|
||||
body = "Unauthorized"
|
||||
case bearerErrForbidden:
|
||||
status = http.StatusForbidden
|
||||
body = "Access denied"
|
||||
case bearerErrThrottled:
|
||||
status = http.StatusTooManyRequests
|
||||
body = "Too Many Requests"
|
||||
retryAfter = t.bearerFailurePenalty
|
||||
case bearerErrIntrospectionUnavailable:
|
||||
status = http.StatusServiceUnavailable
|
||||
body = "Service Unavailable"
|
||||
default:
|
||||
status = http.StatusUnauthorized
|
||||
body = "Unauthorized"
|
||||
}
|
||||
|
||||
if t.bearerEmitWWWAuthenticate && errCode != "" {
|
||||
rw.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error=%q`, errCode))
|
||||
}
|
||||
if retryAfter > 0 {
|
||||
rw.Header().Set("Retry-After", fmt.Sprintf("%d", int(retryAfter.Seconds())))
|
||||
}
|
||||
rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
rw.WriteHeader(status)
|
||||
_, _ = rw.Write([]byte(body)) // Safe to ignore: best-effort error body write
|
||||
|
||||
if t.logger != nil {
|
||||
t.logger.Debugf("bearer auth rejected: status=%d category=%v reason=%q path=%s",
|
||||
status, err.kind, err.reason, req.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// handleBearerRequest is the entry point invoked by ServeHTTP when the
|
||||
// EnableBearerAuth flag is set, the request carries an Authorization: Bearer
|
||||
// header, and the (configurable) cookie-precedence rule allows the bearer
|
||||
// path to run.
|
||||
func (t *TraefikOidc) handleBearerRequest(rw http.ResponseWriter, req *http.Request) {
|
||||
ip := clientIPForBearer(req)
|
||||
|
||||
if blocked, retryAfter := t.bearerFailureTracker.blocked(ip); blocked {
|
||||
throttled := newBearerError(bearerErrThrottled, "ip in penalty box")
|
||||
// Preserve the actual retry-after even if it diverged from the
|
||||
// configured default (clock-skew, partial-window expiry).
|
||||
if retryAfter > 0 {
|
||||
rw.Header().Set("Retry-After", fmt.Sprintf("%d", int(retryAfter.Seconds())))
|
||||
}
|
||||
t.writeBearerError(rw, req, throttled)
|
||||
return
|
||||
}
|
||||
|
||||
token, ok := detectBearerToken(req)
|
||||
if !ok {
|
||||
t.bearerFailureTracker.recordFailure(ip)
|
||||
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidRequest, "missing or empty bearer token"))
|
||||
return
|
||||
}
|
||||
if len(token) > AccessTokenConfig.MaxLength {
|
||||
t.bearerFailureTracker.recordFailure(ip)
|
||||
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidToken, "token exceeds max length"))
|
||||
return
|
||||
}
|
||||
if strings.Count(token, ".") != 2 {
|
||||
t.bearerFailureTracker.recordFailure(ip)
|
||||
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidToken, "token is not a 3-segment JWT"))
|
||||
return
|
||||
}
|
||||
|
||||
if bErr := parseBearerJOSEHeader(token); bErr != nil {
|
||||
t.bearerFailureTracker.recordFailure(ip)
|
||||
t.writeBearerError(rw, req, bErr)
|
||||
return
|
||||
}
|
||||
|
||||
p, bErr := t.buildPrincipalFromBearerToken(token)
|
||||
if bErr != nil {
|
||||
t.bearerFailureTracker.recordFailure(ip)
|
||||
t.writeBearerError(rw, req, bErr)
|
||||
return
|
||||
}
|
||||
|
||||
t.bearerFailureTracker.recordSuccess(ip)
|
||||
if t.logger != nil {
|
||||
t.logger.Debugf("bearer auth success: identifier_hash=%s path=%s",
|
||||
hashIdentifierForLog(p.Identifier), req.URL.Path)
|
||||
}
|
||||
t.forwardAuthorized(rw, req, p)
|
||||
}
|
||||
|
||||
// buildPrincipalFromBearerToken runs the full bearer verification pipeline
|
||||
// described in spec §7.3 and returns a principal ready for forwardAuthorized.
|
||||
// Returns a typed *bearerError on failure so the caller can map to status.
|
||||
func (t *TraefikOidc) buildPrincipalFromBearerToken(token string) (*principal, *bearerError) {
|
||||
if err := t.verifyTokenWithOpts(token, verifyOpts{skipReplayMarking: true}); err != nil {
|
||||
return nil, newBearerError(bearerErrInvalidToken, "token verification failed: "+err.Error())
|
||||
}
|
||||
|
||||
parsed, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return nil, newBearerError(bearerErrInvalidToken, "post-verify parseJWT failed: "+err.Error())
|
||||
}
|
||||
claims := parsed.Claims
|
||||
|
||||
// Token-type guard. Reuse the well-tested classifier which already
|
||||
// checks nonce / typ=at+jwt / token_use / scope / aud-vs-clientID.
|
||||
if t.detectTokenType(parsed, token) {
|
||||
return nil, newBearerError(bearerErrInvalidToken, "ID tokens are not accepted on the bearer path")
|
||||
}
|
||||
// Belt-and-braces explicit rejection (cheap, catches edge cases not
|
||||
// covered by detectTokenType's heuristic).
|
||||
if nonce, ok := claims["nonce"].(string); ok && nonce != "" {
|
||||
return nil, newBearerError(bearerErrInvalidToken, "nonce claim present (ID-token shape)")
|
||||
}
|
||||
if tu, ok := claims["token_use"].(string); ok && tu == "id" {
|
||||
return nil, newBearerError(bearerErrInvalidToken, "token_use=id rejected")
|
||||
}
|
||||
|
||||
if bErr := enforceMultiAudienceAzp(claims, t.clientID); bErr != nil {
|
||||
return nil, bErr
|
||||
}
|
||||
if bErr := enforceIatAge(claims, t.maxTokenAge); bErr != nil {
|
||||
return nil, bErr
|
||||
}
|
||||
|
||||
if t.requireTokenIntrospection {
|
||||
if bErr := t.introspectOnBearerPath(token); bErr != nil {
|
||||
return nil, bErr
|
||||
}
|
||||
}
|
||||
|
||||
rawIdentifier, bErr := resolveBearerIdentifier(claims, t.bearerIdentifierClaim)
|
||||
if bErr != nil {
|
||||
return nil, bErr
|
||||
}
|
||||
identifier, bErr := sanitizeBearerIdentifier(rawIdentifier, t.maxIdentifierLength)
|
||||
if bErr != nil {
|
||||
return nil, bErr
|
||||
}
|
||||
|
||||
subject, _ := claims["sub"].(string)
|
||||
clientID, _ := claims["azp"].(string)
|
||||
if clientID == "" {
|
||||
clientID, _ = claims["client_id"].(string)
|
||||
}
|
||||
|
||||
return &principal{
|
||||
Source: sourceBearer,
|
||||
Identifier: identifier,
|
||||
Subject: subject,
|
||||
ClientID: clientID,
|
||||
Claims: claims,
|
||||
AccessToken: token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// introspectOnBearerPath calls the existing RFC 7662 introspector when the
|
||||
// operator demands real-time revocation. Distinguishes "token revoked" (401)
|
||||
// from "endpoint unavailable" (503) so transient infra failures don't look
|
||||
// like credential failures.
|
||||
func (t *TraefikOidc) introspectOnBearerPath(token string) *bearerError {
|
||||
resp, err := t.introspectToken(token)
|
||||
if err != nil {
|
||||
return newBearerError(bearerErrIntrospectionUnavailable, "introspection failed: "+err.Error())
|
||||
}
|
||||
if !resp.Active {
|
||||
return newBearerError(bearerErrTokenInactive, "introspection reports token inactive")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,830 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Helper builders
|
||||
// =============================================================================
|
||||
|
||||
// makeBearerJWT constructs a JWT with explicit header + claims for tests.
|
||||
// Signature is opaque (b64("signature")) — bearer tests don't exercise the
|
||||
// real cryptographic verifier; verification is bypassed via tokenCache pre-
|
||||
// seed so the bearer pipeline under test sees a "verified" token.
|
||||
func makeBearerJWT(t *testing.T, header, claims map[string]interface{}) string {
|
||||
t.Helper()
|
||||
hb, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal header: %v", err)
|
||||
}
|
||||
cb, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s.%s",
|
||||
base64.RawURLEncoding.EncodeToString(hb),
|
||||
base64.RawURLEncoding.EncodeToString(cb),
|
||||
base64.RawURLEncoding.EncodeToString([]byte("signature")),
|
||||
)
|
||||
}
|
||||
|
||||
// defaultBearerHeader produces the standard RS256+kid header used in tests.
|
||||
func defaultBearerHeader() map[string]interface{} {
|
||||
return map[string]interface{}{"alg": "RS256", "kid": "test-kid"}
|
||||
}
|
||||
|
||||
// defaultBearerClaims produces a baseline access-token claim set. Tests
|
||||
// shallow-clone and override fields as needed.
|
||||
func defaultBearerClaims() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"iss": "https://issuer.example.com",
|
||||
"aud": "https://api.example.com",
|
||||
"sub": "service-account-1",
|
||||
"scope": "api:read api:write",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
}
|
||||
}
|
||||
|
||||
// makeBearerOIDC constructs a TraefikOidc wired for bearer auth tests. The
|
||||
// real verifyTokenWithOpts pipeline is short-circuited via tokenCache pre-
|
||||
// seed: any token Set into t.tokenCache returns nil from VerifyToken,
|
||||
// letting tests exercise the post-verify bearer logic (classifier, identifier,
|
||||
// throttle, header forwarding) without standing up JWKs.
|
||||
func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
|
||||
t.Helper()
|
||||
sm := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("error"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://issuer.example.com",
|
||||
audience: "https://api.example.com",
|
||||
clientID: "https://api.example.com",
|
||||
tokenCache: NewTokenCache(),
|
||||
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
|
||||
ctx: context.Background(),
|
||||
enableBearerAuth: true,
|
||||
stripAuthorizationHeader: true,
|
||||
bearerEmitWWWAuthenticate: true,
|
||||
bearerOverridesCookie: false,
|
||||
bearerIdentifierClaim: "sub",
|
||||
maxIdentifierLength: 256,
|
||||
maxTokenAge: 24 * time.Hour,
|
||||
bearerFailureThreshold: 20,
|
||||
bearerFailureWindow: 60 * time.Second,
|
||||
bearerFailurePenalty: 60 * time.Second,
|
||||
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
|
||||
}
|
||||
oidc.extractClaimsFunc = extractClaims
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
|
||||
// seedVerified pre-populates the tokenCache so verifyTokenWithOpts short-
|
||||
// circuits to nil for the given token. Mirrors the production fast-return
|
||||
// path at token_manager.go for previously-verified tokens.
|
||||
func seedVerified(t *testing.T, oidc *TraefikOidc, token string, claims map[string]interface{}) {
|
||||
t.Helper()
|
||||
if oidc.tokenCache == nil {
|
||||
oidc.tokenCache = NewTokenCache()
|
||||
}
|
||||
oidc.tokenCache.Set(token, claims, time.Hour)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Unit tests — small helpers
|
||||
// =============================================================================
|
||||
|
||||
func TestDetectBearerToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
header string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{"missing header", "", "", false},
|
||||
{"basic auth", "Basic abc", "", false},
|
||||
{"bearer with token", "Bearer abc.def.ghi", "abc.def.ghi", true},
|
||||
{"lowercase bearer", "bearer abc.def.ghi", "abc.def.ghi", true},
|
||||
{"mixed case", "BeArEr abc.def.ghi", "abc.def.ghi", true},
|
||||
{"empty token after prefix", "Bearer ", "", false},
|
||||
{"bearer no space", "Bearerabc", "", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tc.header != "" {
|
||||
req.Header.Set("Authorization", tc.header)
|
||||
}
|
||||
got, ok := detectBearerToken(req)
|
||||
if ok != tc.ok || got != tc.want {
|
||||
t.Fatalf("got=(%q, %v), want=(%q, %v)", got, ok, tc.want, tc.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBearerJOSEHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
mk := func(t *testing.T, h map[string]interface{}) string {
|
||||
return makeBearerJWT(t, h, map[string]interface{}{"sub": "x"})
|
||||
}
|
||||
cases := []struct {
|
||||
header map[string]interface{}
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "valid RS256", header: map[string]interface{}{"alg": "RS256", "kid": "k1"}, wantErr: false},
|
||||
{name: "valid ES512", header: map[string]interface{}{"alg": "ES512", "kid": "abc-_.="}, wantErr: false},
|
||||
{name: "alg=none rejected", header: map[string]interface{}{"alg": "none", "kid": "k1"}, wantErr: true},
|
||||
{name: "alg=HS256 rejected", header: map[string]interface{}{"alg": "HS256", "kid": "k1"}, wantErr: true},
|
||||
{name: "missing kid", header: map[string]interface{}{"alg": "RS256"}, wantErr: true},
|
||||
{name: "kid too long", header: map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}, wantErr: true},
|
||||
{name: "kid bad chars", header: map[string]interface{}{"alg": "RS256", "kid": "evil/../etc/passwd"}, wantErr: true},
|
||||
{name: "kid with space", header: map[string]interface{}{"alg": "RS256", "kid": "key one"}, wantErr: true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token := mk(t, tc.header)
|
||||
err := parseBearerJOSEHeader(token)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitiseBearerIdentifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"normal sub", "service-account-1", "service-account-1", false},
|
||||
{"email-like", "alice@example.com", "alice@example.com", false},
|
||||
{"trim whitespace", " abc ", "abc", false},
|
||||
{"empty", "", "", true},
|
||||
{"only whitespace", " ", "", true},
|
||||
{"control char (newline)", "alice\nbob", "", true},
|
||||
{"control char (CR)", "alice\rbob", "", true},
|
||||
{"control char (NUL)", "alice\x00bob", "", true},
|
||||
{"bidi override", "alice\u202ebob", "", true},
|
||||
{"bidi isolate", "alice\u2066bob", "", true},
|
||||
{"comma delimiter", "alice,bob", "", true},
|
||||
{"semicolon delimiter", "alice;bob", "", true},
|
||||
{"equals delimiter", "alice=bob", "", true},
|
||||
{"over length", strings.Repeat("a", 257), "", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := sanitizeBearerIdentifier(tc.in, 256)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
if !tc.wantErr && got != tc.want {
|
||||
t.Fatalf("got=%q want=%q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBearerIdentifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
claims map[string]interface{}
|
||||
name string
|
||||
claim string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "default sub", claims: map[string]interface{}{"sub": "abc"}, claim: "", want: "abc"},
|
||||
{name: "explicit sub", claims: map[string]interface{}{"sub": "abc"}, claim: "sub", want: "abc"},
|
||||
{name: "custom client_id claim", claims: map[string]interface{}{"client_id": "svc"}, claim: "client_id", want: "svc"},
|
||||
{name: "missing claim", claims: map[string]interface{}{"other": "x"}, claim: "sub", wantErr: true},
|
||||
{name: "non-string claim", claims: map[string]interface{}{"sub": 123}, claim: "sub", wantErr: true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := resolveBearerIdentifier(tc.claims, tc.claim)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
if !tc.wantErr && got != tc.want {
|
||||
t.Fatalf("got=%q want=%q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceMultiAudienceAzp(t *testing.T) {
|
||||
t.Parallel()
|
||||
const cid = "https://api.example.com"
|
||||
cases := []struct {
|
||||
claims map[string]interface{}
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "single string aud", claims: map[string]interface{}{"aud": "x"}, wantErr: false},
|
||||
{name: "single element array", claims: map[string]interface{}{"aud": []interface{}{"x"}}, wantErr: false},
|
||||
{name: "multi-aud with matching azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": cid}, wantErr: false},
|
||||
{name: "multi-aud missing azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}}, wantErr: true},
|
||||
{name: "multi-aud empty azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": ""}, wantErr: true},
|
||||
{name: "multi-aud wrong azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": "other"}, wantErr: true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := enforceMultiAudienceAzp(tc.claims, cid)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceIatAge(t *testing.T) {
|
||||
t.Parallel()
|
||||
now := time.Now()
|
||||
cases := []struct {
|
||||
name string
|
||||
iat float64
|
||||
maxAge time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "fresh", iat: float64(now.Unix()), maxAge: time.Hour, wantErr: false},
|
||||
{name: "23h59m old, max 24h", iat: float64(now.Add(-23*time.Hour - 59*time.Minute).Unix()), maxAge: 24 * time.Hour, wantErr: false},
|
||||
{name: "25h old, max 24h", iat: float64(now.Add(-25 * time.Hour).Unix()), maxAge: 24 * time.Hour, wantErr: true},
|
||||
{name: "1970 token", iat: float64(0), maxAge: 24 * time.Hour, wantErr: true},
|
||||
{name: "maxAge disabled (0)", iat: float64(0), maxAge: 0, wantErr: false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := enforceIatAge(map[string]interface{}{"iat": tc.iat}, tc.maxAge)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBearerFailureTracker(t *testing.T) {
|
||||
t.Parallel()
|
||||
tr := newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
|
||||
const ip = "10.0.0.1"
|
||||
// Below threshold: not blocked.
|
||||
for i := 0; i < 2; i++ {
|
||||
tr.recordFailure(ip)
|
||||
if b, _ := tr.blocked(ip); b {
|
||||
t.Fatalf("blocked too early after %d failures", i+1)
|
||||
}
|
||||
}
|
||||
// Threshold reached: blocked.
|
||||
tr.recordFailure(ip)
|
||||
if b, retry := tr.blocked(ip); !b || retry <= 0 {
|
||||
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
|
||||
}
|
||||
// A success while a penalty is active must NOT wipe the in-effect lockout
|
||||
// (otherwise a single success could clear an attacker's penalty).
|
||||
tr.recordSuccess(ip)
|
||||
if b, _ := tr.blocked(ip); !b {
|
||||
t.Fatalf("expected still blocked after success while penalty active")
|
||||
}
|
||||
// Other IPs are unaffected.
|
||||
if b, _ := tr.blocked("10.0.0.2"); b {
|
||||
t.Fatalf("unrelated IP should not be blocked")
|
||||
}
|
||||
|
||||
// With an expired penalty, a success resets the counter so a subsequent
|
||||
// sub-threshold failure does not immediately re-block.
|
||||
tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond)
|
||||
const ip2 = "10.0.0.3"
|
||||
for i := 0; i < 3; i++ {
|
||||
tr2.recordFailure(ip2)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond) // let the short penalty expire
|
||||
if b, _ := tr2.blocked(ip2); b {
|
||||
t.Fatalf("expected unblocked after penalty expiry")
|
||||
}
|
||||
tr2.recordSuccess(ip2) // resets count since penalty has passed
|
||||
tr2.recordFailure(ip2) // single failure, well below threshold
|
||||
if b, _ := tr2.blocked(ip2); b {
|
||||
t.Fatalf("expected unblocked: counter should have reset after success")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Integration tests — full ServeHTTP via the bearer pipeline
|
||||
// =============================================================================
|
||||
|
||||
func TestServeHTTP_Bearer_HappyPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
var nextCalled atomic.Bool
|
||||
var capturedHeaders http.Header
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled.Store(true)
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled.Load() {
|
||||
t.Fatalf("expected next handler to run; got status=%d body=%q", rw.Code, rw.Body.String())
|
||||
}
|
||||
if rw.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", rw.Code)
|
||||
}
|
||||
if got := capturedHeaders.Get("X-Forwarded-User"); got != "service-account-1" {
|
||||
t.Fatalf("X-Forwarded-User=%q, want service-account-1", got)
|
||||
}
|
||||
if got := capturedHeaders.Get("Authorization"); got != "" {
|
||||
t.Fatalf("Authorization should be stripped, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_StripAuthDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
var capturedAuth string
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
oidc.stripAuthorizationHeader = false
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !strings.HasPrefix(capturedAuth, "Bearer ") {
|
||||
t.Fatalf("expected Authorization to be forwarded, got=%q", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_RejectIDToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for ID token rejection")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
// ID-token shape: nonce claim present and no scope. detectTokenType
|
||||
// returns true.
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://issuer.example.com",
|
||||
"aud": "https://api.example.com",
|
||||
"sub": "user-1",
|
||||
"nonce": "n-0S6_WzA2Mj",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
}
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
if wa := rw.Header().Get("WWW-Authenticate"); !strings.Contains(wa, `error="invalid_token"`) {
|
||||
t.Fatalf("expected WWW-Authenticate invalid_token, got=%q", wa)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_AlgNoneRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for alg=none")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
header := map[string]interface{}{"alg": "none", "kid": "k1"}
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, header, claims)
|
||||
// Even if we pre-seeded the cache, the early alg pin runs FIRST.
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_KidTooLongRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for oversized kid")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
header := map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, header, claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_MultiAudRequiresAzp(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for multi-aud without azp")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
|
||||
delete(claims, "azp")
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_MultiAudWithAzpAccepted(t *testing.T) {
|
||||
t.Parallel()
|
||||
var nextCalled atomic.Bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled.Store(true)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
|
||||
claims["azp"] = oidc.clientID
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusOK || !nextCalled.Load() {
|
||||
t.Fatalf("expected 200 + next called; got status=%d called=%v", rw.Code, nextCalled.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_IatTooOldRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for old iat")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
claims["iat"] = float64(time.Now().Add(-25 * time.Hour).Unix())
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_IdentifierWithBidiRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for bidi identifier")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
claims["sub"] = "alice\u202ebob"
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_ReplayRegression(t *testing.T) {
|
||||
t.Parallel()
|
||||
var successCount atomic.Int32
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
successCount.Add(1)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
claims["jti"] = "regression-jti"
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code != http.StatusOK {
|
||||
t.Fatalf("iteration %d: status=%d, want 200", i, rw.Code)
|
||||
}
|
||||
}
|
||||
if successCount.Load() != 100 {
|
||||
t.Fatalf("successCount=%d, want 100", successCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_ThrottleTrips429(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run during throttle test")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
oidc.bearerFailureTracker = newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
|
||||
|
||||
// Send malformed bearers from the same RemoteAddr until threshold trips.
|
||||
send := func() *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.RemoteAddr = "10.0.0.5:1234"
|
||||
req.Header.Set("Authorization", "Bearer not-a-jwt")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
return rw
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
rw := send()
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("pre-throttle iteration %d: status=%d, want 401", i, rw.Code)
|
||||
}
|
||||
}
|
||||
// 4th request: throttled.
|
||||
rw := send()
|
||||
if rw.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 after threshold, got %d", rw.Code)
|
||||
}
|
||||
if ra := rw.Header().Get("Retry-After"); ra == "" {
|
||||
t.Fatalf("expected Retry-After header on 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_ExcludedURLStripsAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
var capturedAuth string
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
oidc.excludedURLs = map[string]struct{}{"/favicon.ico": {}}
|
||||
|
||||
req := httptest.NewRequest("GET", "/favicon.ico", nil)
|
||||
req.Header.Set("Authorization", "Bearer abc.def.ghi")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusOK {
|
||||
t.Fatalf("excluded path should pass; got %d", rw.Code)
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Fatalf("Authorization must be stripped on excluded paths, got=%q", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_RolesGate(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
rolesClaim []interface{}
|
||||
want int
|
||||
}{
|
||||
{name: "matching role", rolesClaim: []interface{}{"admin"}, want: http.StatusOK},
|
||||
{name: "no matching role", rolesClaim: []interface{}{"viewer"}, want: http.StatusForbidden},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
oidc.allowedRolesAndGroups = map[string]struct{}{"admin": {}}
|
||||
oidc.roleClaimName = "roles"
|
||||
claims := defaultBearerClaims()
|
||||
claims["roles"] = tc.rolesClaim
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code != tc.want {
|
||||
t.Fatalf("status=%d, want %d", rw.Code, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_CookieWinsByDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Both cookie and bearer present: cookie path runs (which will redirect
|
||||
// to /authorize since the cookie is empty/unauthenticated).
|
||||
var nextCalled atomic.Bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled.Store(true)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
prefix := oidc.sessionManager.GetCookiePrefix()
|
||||
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// Cookie path consumed the request; bearer was ignored. Since the
|
||||
// cookie is empty, the cookie path will either 302 to /authorize or
|
||||
// return 401 — in either case, next must NOT be called.
|
||||
if nextCalled.Load() {
|
||||
t.Fatalf("next must not be called when bearer is ignored due to cookie precedence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_BearerOverridesCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
var nextCalled atomic.Bool
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled.Store(true)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
oidc.bearerOverridesCookie = true
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
prefix := oidc.sessionManager.GetCookiePrefix()
|
||||
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled.Load() || rw.Code != http.StatusOK {
|
||||
t.Fatalf("expected bearer to win with override; status=%d called=%v", rw.Code, nextCalled.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_OversizedToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for oversized token")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
huge := strings.Repeat("a", AccessTokenConfig.MaxLength+1)
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+huge)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_MalformedJWT(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("next must not run for malformed JWT")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer not.jwt") // 1 dot
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTP_Bearer_FeatureOffPassesThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Should not be reached: cookie path runs and (with no session)
|
||||
// will redirect or 401. We assert no panic / next not called.
|
||||
t.Fatalf("next must not run when bearer is off and no valid session exists")
|
||||
})
|
||||
oidc := makeBearerOIDC(t, next)
|
||||
oidc.enableBearerAuth = false
|
||||
claims := defaultBearerClaims()
|
||||
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
||||
seedVerified(t, oidc, token, claims)
|
||||
req := httptest.NewRequest("GET", "/api/work", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
// Expect non-200: either 302 to /authorize or 401. The point is the
|
||||
// bearer pipeline didn't run.
|
||||
if rw.Code == http.StatusOK {
|
||||
t.Fatalf("expected non-200 when bearer is off; got %d", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Startup validation tests
|
||||
// =============================================================================
|
||||
|
||||
func TestStartupValidation_BearerRequiresAudience(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://issuer.example.com"
|
||||
cfg.ClientID = "id"
|
||||
cfg.ClientSecret = "secret"
|
||||
cfg.CallbackURL = "/oauth/callback"
|
||||
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
cfg.EnableBearerAuth = true
|
||||
cfg.Audience = ""
|
||||
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
|
||||
if err == nil || !strings.Contains(err.Error(), "requires Audience") {
|
||||
t.Fatalf("expected audience-required error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupValidation_BearerRejectsEmailIdentifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://issuer.example.com"
|
||||
cfg.ClientID = "id"
|
||||
cfg.ClientSecret = "secret"
|
||||
cfg.CallbackURL = "/oauth/callback"
|
||||
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
cfg.EnableBearerAuth = true
|
||||
cfg.Audience = "https://api.example.com"
|
||||
cfg.BearerIdentifierClaim = "email"
|
||||
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
|
||||
if err == nil || !strings.Contains(err.Error(), "bearerIdentifierClaim=\"email\"") {
|
||||
t.Fatalf("expected email-identifier rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Principal invariants
|
||||
// =============================================================================
|
||||
|
||||
func TestBuildPrincipalFromSession_NoIdentifier(t *testing.T) {
|
||||
t.Parallel()
|
||||
oidc := &TraefikOidc{logger: NewLogger("error")}
|
||||
if p := oidc.buildPrincipalFromSession(nil); p != nil {
|
||||
t.Fatalf("nil session must produce nil principal")
|
||||
}
|
||||
}
|
||||
+137
@@ -0,0 +1,137 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testCertPEM returns a valid PEM-encoded certificate harvested from an
|
||||
// httptest.NewTLSServer. Using httptest keeps the test free of any
|
||||
// handwritten static cert that could expire.
|
||||
func testCertPEM(t *testing.T) string {
|
||||
t.Helper()
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
cert := srv.Certificate()
|
||||
if cert == nil {
|
||||
t.Fatal("httptest.NewTLSServer did not expose a certificate")
|
||||
}
|
||||
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_Empty(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool != nil {
|
||||
t.Errorf("expected nil pool when no CA source configured, got %v", pool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_InlinePEM(t *testing.T) {
|
||||
cfg := &Config{CACertPEM: testCertPEM(t)}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("expected non-nil pool for valid CACertPEM")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_InlinePEM_Garbage(t *testing.T) {
|
||||
cfg := &Config{CACertPEM: "not a pem"}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for garbage CACertPEM, got nil")
|
||||
}
|
||||
if pool != nil {
|
||||
t.Errorf("expected nil pool on error, got %v", pool)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "caCertPEM") {
|
||||
t.Errorf("error should name the failing field, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_FilePath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ca.pem")
|
||||
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
|
||||
t.Fatalf("writing temp PEM: %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{CACertPath: path}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("expected non-nil pool for valid CACertPath")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_FilePath_Missing(t *testing.T) {
|
||||
cfg := &Config{CACertPath: "/does/not/exist/ca.pem"}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing CACertPath, got nil")
|
||||
}
|
||||
if pool != nil {
|
||||
t.Errorf("expected nil pool on error, got %v", pool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_Combined(t *testing.T) {
|
||||
// Both inline and file sources populated — certificates from both should
|
||||
// be accepted into the same pool.
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ca.pem")
|
||||
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
|
||||
t.Fatalf("writing temp PEM: %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{CACertPath: path, CACertPEM: testCertPEM(t)}
|
||||
pool, err := cfg.loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("expected non-nil pool when both sources set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedTransportPool_ConfigKeyDistinguishesCAAndSkipVerify(t *testing.T) {
|
||||
p := GetGlobalTransportPool()
|
||||
cfgSystem := DefaultHTTPClientConfig()
|
||||
|
||||
cfgSkip := DefaultHTTPClientConfig()
|
||||
cfgSkip.InsecureSkipVerify = true
|
||||
|
||||
cfgCustomCA := DefaultHTTPClientConfig()
|
||||
pool, err := (&Config{CACertPEM: testCertPEM(t)}).loadCACertPool()
|
||||
if err != nil {
|
||||
t.Fatalf("loadCACertPool: %v", err)
|
||||
}
|
||||
cfgCustomCA.RootCAs = pool
|
||||
|
||||
keys := map[string]string{
|
||||
"system": p.configKey(cfgSystem),
|
||||
"skip": p.configKey(cfgSkip),
|
||||
"customCA": p.configKey(cfgCustomCA),
|
||||
}
|
||||
seen := make(map[string]string, len(keys))
|
||||
for name, key := range keys {
|
||||
if dup, ok := seen[key]; ok {
|
||||
t.Errorf("configKey collision: %s and %s share key %q", name, dup, key)
|
||||
}
|
||||
seen[key] = name
|
||||
}
|
||||
}
|
||||
+34
-4
@@ -16,19 +16,23 @@ type CacheManager struct {
|
||||
}
|
||||
|
||||
var (
|
||||
globalCacheManagerInstance *CacheManager
|
||||
cacheManagerInitOnce sync.Once
|
||||
globalCacheManagerInstance *CacheManager
|
||||
cacheManagerInitOnce sync.Once
|
||||
cacheManagerActiveFingerprint string
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance.
|
||||
//
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
return GetGlobalCacheManagerWithConfig(wg, nil)
|
||||
}
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
fp := redisFingerprint(config)
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
cacheManagerActiveFingerprint = fp
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
@@ -54,9 +58,27 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
// Warn loudly if a later instance asks for a DIFFERENT explicit Redis
|
||||
// backend than the one that won initialization: the cache manager is a
|
||||
// process-global singleton shared across plugin instances (yaegi), so this
|
||||
// instance's divergent configuration is silently ignored, which would
|
||||
// otherwise collapse cache/state isolation between routes (rank 9).
|
||||
if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint {
|
||||
NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp)
|
||||
}
|
||||
return globalCacheManagerInstance
|
||||
}
|
||||
|
||||
// redisFingerprint returns a stable identifier for an explicitly-enabled Redis
|
||||
// backend (address + key prefix), or "" when Redis is not explicitly enabled.
|
||||
// Used to detect divergent cache configurations across plugin instances.
|
||||
func redisFingerprint(config *Config) string {
|
||||
if config == nil || config.Redis == nil || !config.Redis.Enabled {
|
||||
return ""
|
||||
}
|
||||
return config.Redis.Address + "|" + config.Redis.KeyPrefix
|
||||
}
|
||||
|
||||
// GetSharedTokenBlacklist returns the shared token blacklist cache
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
@@ -112,6 +134,14 @@ func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
|
||||
// by the refresh path to coalesce grants across Traefik replicas via Redis.
|
||||
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// isSupportedClientAssertionAlg reports whether alg is a recognized JWS
|
||||
// algorithm for private_key_jwt (RFC 7523 §2.2).
|
||||
func isSupportedClientAssertionAlg(alg string) bool {
|
||||
switch alg {
|
||||
case "RS256", "RS384", "RS512",
|
||||
"PS256", "PS384", "PS512",
|
||||
"ES256", "ES384", "ES512":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ClientAssertionSigner builds and signs client_assertion JWTs (RFC 7523 §2.2).
|
||||
type ClientAssertionSigner struct {
|
||||
key crypto.PrivateKey
|
||||
alg string
|
||||
kid string
|
||||
// rand is the entropy source for jti generation and PSS/ECDSA signing.
|
||||
// Defaults to crypto/rand.Reader when nil.
|
||||
rand io.Reader
|
||||
// now returns the current time. Defaults to time.Now when nil.
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// NewClientAssertionSigner parses pemBytes as a private key, validates that
|
||||
// alg is consistent with the key type, and returns a ready-to-use signer.
|
||||
// kid is placed verbatim in the JWS header.
|
||||
//
|
||||
// PEM block types understood:
|
||||
// - "PRIVATE KEY" → PKCS#8 (tried first for all types)
|
||||
// - "RSA PRIVATE KEY" → PKCS#1
|
||||
// - "EC PRIVATE KEY" → SEC1
|
||||
func NewClientAssertionSigner(pemBytes []byte, alg, kid string) (*ClientAssertionSigner, error) {
|
||||
if !isSupportedClientAssertionAlg(alg) {
|
||||
return nil, fmt.Errorf("unsupported client assertion alg %q", alg)
|
||||
}
|
||||
if kid == "" {
|
||||
return nil, fmt.Errorf("kid must not be empty")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block found in private key material")
|
||||
}
|
||||
|
||||
var key crypto.PrivateKey
|
||||
var parseErr error
|
||||
|
||||
switch block.Type {
|
||||
case "PRIVATE KEY":
|
||||
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
case "RSA PRIVATE KEY":
|
||||
key, parseErr = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
key, parseErr = x509.ParseECPrivateKey(block.Bytes)
|
||||
default:
|
||||
// Best-effort fallback for unknown block types.
|
||||
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key (block type %q): %w", block.Type, parseErr)
|
||||
}
|
||||
|
||||
if err := validateAlgKeyMatch(alg, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ClientAssertionSigner{key: key, alg: alg, kid: kid}, nil
|
||||
}
|
||||
|
||||
// validateAlgKeyMatch returns an error when alg implies a key type that does
|
||||
// not match the actual key.
|
||||
func validateAlgKeyMatch(alg string, key crypto.PrivateKey) error {
|
||||
switch alg[0] {
|
||||
case 'R', 'P': // RS* or PS*
|
||||
if _, ok := key.(*rsa.PrivateKey); !ok {
|
||||
return fmt.Errorf("alg %q requires an RSA key, got %T", alg, key)
|
||||
}
|
||||
case 'E': // ES*
|
||||
if _, ok := key.(*ecdsa.PrivateKey); !ok {
|
||||
return fmt.Errorf("alg %q requires an EC key, got %T", alg, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign constructs and returns a signed client_assertion JWT.
|
||||
// audience is typically the token endpoint URL (RFC 7523 §3).
|
||||
// clientID is used as both iss and sub per RFC 7523 §2.2.
|
||||
func (s *ClientAssertionSigner) Sign(audience, clientID string) (string, error) {
|
||||
rander := s.rand
|
||||
if rander == nil {
|
||||
rander = rand.Reader
|
||||
}
|
||||
nowFn := s.now
|
||||
if nowFn == nil {
|
||||
nowFn = time.Now
|
||||
}
|
||||
|
||||
now := nowFn()
|
||||
|
||||
// 16 random bytes as lowercase hex for jti uniqueness.
|
||||
jtiBytes := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rander, jtiBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate jti: %w", err)
|
||||
}
|
||||
jti := hex.EncodeToString(jtiBytes)
|
||||
|
||||
header := map[string]string{
|
||||
"alg": s.alg,
|
||||
"typ": "JWT",
|
||||
"kid": s.kid,
|
||||
}
|
||||
hdrJSON, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
|
||||
}
|
||||
|
||||
claims := map[string]any{
|
||||
"iss": clientID,
|
||||
"sub": clientID,
|
||||
"aud": audience,
|
||||
"jti": jti,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(60 * time.Second).Unix(),
|
||||
}
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal JWT claims: %w", err)
|
||||
}
|
||||
|
||||
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signingInput := hdrB64 + "." + claimsB64
|
||||
|
||||
sig, err := s.sign(rander, []byte(signingInput))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
|
||||
}
|
||||
|
||||
// sign computes raw signature bytes for signingInput per s.alg.
|
||||
// validateAlgKeyMatch in NewClientAssertionSigner guarantees the key type
|
||||
// matches s.alg, but the comma-ok asserts here keep errcheck happy and
|
||||
// surface internal misuse loudly instead of via panic.
|
||||
func (s *ClientAssertionSigner) sign(rander io.Reader, input []byte) ([]byte, error) {
|
||||
switch s.alg {
|
||||
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
|
||||
rsaKey, ok := s.key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal: alg %q requires *rsa.PrivateKey, got %T", s.alg, s.key)
|
||||
}
|
||||
hash := rsaHashForAlg(s.alg)
|
||||
digest := hashSum(hash, input)
|
||||
if s.alg[0] == 'R' {
|
||||
return signRSAPKCS1v15(rander, rsaKey, hash, digest)
|
||||
}
|
||||
return signRSAPSS(rander, rsaKey, hash, digest)
|
||||
case "ES256", "ES384", "ES512":
|
||||
ecKey, ok := s.key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal: alg %q requires *ecdsa.PrivateKey, got %T", s.alg, s.key)
|
||||
}
|
||||
hash := ecHashForAlg(s.alg)
|
||||
digest := hashSum(hash, input)
|
||||
return signECDSA(rander, ecKey, digest)
|
||||
}
|
||||
return nil, fmt.Errorf("unhandled alg %q", s.alg)
|
||||
}
|
||||
|
||||
func rsaHashForAlg(alg string) crypto.Hash {
|
||||
switch alg {
|
||||
case "RS256", "PS256":
|
||||
return crypto.SHA256
|
||||
case "RS384", "PS384":
|
||||
return crypto.SHA384
|
||||
case "RS512", "PS512":
|
||||
return crypto.SHA512
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func ecHashForAlg(alg string) crypto.Hash {
|
||||
switch alg {
|
||||
case "ES256":
|
||||
return crypto.SHA256
|
||||
case "ES384":
|
||||
return crypto.SHA384
|
||||
case "ES512":
|
||||
return crypto.SHA512
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hashSum(h crypto.Hash, input []byte) []byte {
|
||||
switch h {
|
||||
case crypto.SHA256:
|
||||
sum := sha256.Sum256(input)
|
||||
return sum[:]
|
||||
case crypto.SHA384:
|
||||
sum := sha512.Sum384(input)
|
||||
return sum[:]
|
||||
case crypto.SHA512:
|
||||
sum := sha512.Sum512(input)
|
||||
return sum[:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func signRSAPKCS1v15(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
sig, err := rsa.SignPKCS1v15(rander, key, hash, digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RSA PKCS1v15 signing failed: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
func signRSAPSS(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hash}
|
||||
sig, err := rsa.SignPSS(rander, key, hash, digest, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RSA PSS signing failed: %w", err)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// signECDSA produces the JWS raw r||s signature (RFC 7515 App. A.3).
|
||||
// Each scalar is zero-padded to (curve.BitSize+7)/8 bytes.
|
||||
func signECDSA(rander io.Reader, key *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
|
||||
r, ss, err := ecdsa.Sign(rander, key, digest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ECDSA signing failed: %w", err)
|
||||
}
|
||||
byteLen := (key.Curve.Params().BitSize + 7) / 8
|
||||
sig := make([]byte, 2*byteLen)
|
||||
padBigInt(sig[0:byteLen], r)
|
||||
padBigInt(sig[byteLen:], ss)
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// padBigInt writes n as a fixed-width big-endian integer into buf.
|
||||
func padBigInt(buf []byte, n *big.Int) {
|
||||
b := n.Bytes()
|
||||
copy(buf[len(buf)-len(b):], b)
|
||||
}
|
||||
|
||||
// buildClientAssertionSignerFromConfig loads key material and constructs a
|
||||
// ClientAssertionSigner. Called from NewWithContext when
|
||||
// ClientAuthMethod == "private_key_jwt".
|
||||
func buildClientAssertionSignerFromConfig(config *Config) (*ClientAssertionSigner, error) {
|
||||
var pemBytes []byte
|
||||
|
||||
if config.ClientAssertionPrivateKey != "" {
|
||||
pemBytes = []byte(config.ClientAssertionPrivateKey)
|
||||
} else {
|
||||
data, err := os.ReadFile(config.ClientAssertionKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read clientAssertionKeyPath %q: %w", config.ClientAssertionKeyPath, err)
|
||||
}
|
||||
pemBytes = data
|
||||
}
|
||||
|
||||
alg := config.ClientAssertionAlg
|
||||
if alg == "" {
|
||||
alg = "RS256"
|
||||
}
|
||||
|
||||
return NewClientAssertionSigner(pemBytes, alg, config.ClientAssertionKeyID)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build ignore
|
||||
|
||||
// Command yaegicheck verifies that the traefikoidc plugin can be imported and
|
||||
// instantiated by the yaegi interpreter — the same way Traefik loads a plugin.
|
||||
//
|
||||
// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi
|
||||
// to interpret every source file in the package (and its vendored
|
||||
// dependencies), so any construct yaegi cannot handle (unsupported stdlib
|
||||
// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load
|
||||
// time. CreateConfig + New additionally exercise the instantiation path
|
||||
// (session manager, cookie codec, caches, key derivation) under the interpreter.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
oidc "github.com/lukaszraczylo/traefikoidc"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := oidc.CreateConfig()
|
||||
cfg.ProviderURL = "https://accounts.google.com"
|
||||
cfg.ClientID = "yaegi-check-client"
|
||||
cfg.ClientSecret = "yaegi-check-secret"
|
||||
cfg.CallbackURL = "/oauth2/callback"
|
||||
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef"
|
||||
cfg.RateLimit = 100
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
h, err := oidc.New(context.Background(), next, cfg, "yaegi-check")
|
||||
if err != nil {
|
||||
fmt.Println("FAIL: New returned an error under yaegi:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if h == nil {
|
||||
fmt.Println("FAIL: New returned a nil handler under yaegi")
|
||||
os.Exit(1)
|
||||
}
|
||||
if closer, ok := h.(interface{ Close() error }); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi")
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
@@ -47,7 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalYAML() (interface{}, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
|
||||
@@ -278,82 +278,6 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SECURITY MONITORING TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
config := SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 5,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 30,
|
||||
RapidFailureThreshold: 3,
|
||||
CleanupIntervalMinutes: 60,
|
||||
RetentionHours: 24,
|
||||
EnablePatternDetection: true,
|
||||
EnableDetailedLogging: false,
|
||||
LogSuspiciousOnly: false,
|
||||
}
|
||||
|
||||
sm := NewSecurityMonitor(config, logger)
|
||||
if sm == nil {
|
||||
t.Fatal("Expected non-nil SecurityMonitor")
|
||||
}
|
||||
|
||||
// Start cleanup routine first (lowercase method)
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop cleanup routine (public method)
|
||||
sm.StopCleanupRoutine()
|
||||
|
||||
// Stop again should be safe
|
||||
sm.StopCleanupRoutine()
|
||||
}
|
||||
|
||||
func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
config := SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 5,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 30,
|
||||
RapidFailureThreshold: 3,
|
||||
CleanupIntervalMinutes: 60,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
|
||||
sm := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Create handler
|
||||
handler := &LoggingSecurityEventHandler{logger: logger}
|
||||
|
||||
// Register handler using AddEventHandler
|
||||
sm.AddEventHandler(handler)
|
||||
|
||||
// Record a failure to trigger events
|
||||
sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil)
|
||||
}
|
||||
|
||||
func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
handler := &LoggingSecurityEventHandler{logger: logger}
|
||||
|
||||
// Severity is a string in this implementation
|
||||
events := []SecurityEvent{
|
||||
{Type: "test", Severity: "low", Message: "low severity"},
|
||||
{Type: "test", Severity: "medium", Message: "medium severity"},
|
||||
{Type: "test", Severity: "high", Message: "high severity"},
|
||||
{Type: "test", Severity: "critical", Message: "critical severity"},
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SESSION MANAGER TESTS
|
||||
// =============================================================================
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetRefreshToken("old-refresh-token")
|
||||
session.SetIDToken("old-id-token")
|
||||
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Now perform selective clearing (as done in the fix)
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// Set initial session data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-token")
|
||||
session.SetCSRF("existing-csrf")
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
|
||||
// NEW BEHAVIOR: Selective clearing
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
|
||||
@@ -25,7 +25,10 @@ The **audience** (`aud`) claim in a JWT identifies the intended recipient of the
|
||||
|
||||
### Why Does This Matter?
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
Audience validation rejects access tokens whose `aud` claim does not match the
|
||||
expected audience, blocking the trivial form of token confusion where a token
|
||||
issued for API A is presented to API B. (Defence in depth — pair with
|
||||
short-lived tokens, rotation, and per-API client credentials.)
|
||||
|
||||
---
|
||||
|
||||
@@ -137,8 +140,8 @@ http:
|
||||
**Recommended:** `true` for production
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
|
||||
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
|
||||
- When `true`: On audience mismatch, the middleware does **not** silently fall back to ID-token validation. It tries to refresh the access token first; if no refresh token is present (or refresh fails), the user is re-authenticated.
|
||||
- When `false`: Logs warnings and falls back to ID-token validation (backward compatible).
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
@@ -349,7 +352,7 @@ When opaque tokens are detected:
|
||||
|
||||
**Cache behavior:**
|
||||
- Cache key: Token hash
|
||||
- TTL: 5 minutes or token expiry (whichever is shorter)
|
||||
- TTL: 5 minutes; if the token's `exp` is sooner, the cache entry expires at `exp` instead. Tokens without `exp` use the flat 5-minute TTL.
|
||||
- Reduces introspection requests for frequently used tokens
|
||||
|
||||
---
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# Bearer Token (M2M) Authentication
|
||||
|
||||
Opt-in path that lets API clients present `Authorization: Bearer <jwt>` to
|
||||
authenticate without going through the cookie-based OIDC redirect flow.
|
||||
Designed for machine-to-machine (M2M) traffic — services calling other
|
||||
services with tokens minted by your OIDC provider.
|
||||
|
||||
The bearer path lives next to the cookie path: both go through the same
|
||||
post-auth pipeline (`forwardAuthorized`) that injects identity headers,
|
||||
checks `allowedRolesAndGroups`, applies security headers, and forwards to
|
||||
the backend. The only thing that differs is how the principal is established
|
||||
for that single request.
|
||||
|
||||
## Quick start
|
||||
|
||||
```yaml
|
||||
enableBearerAuth: true
|
||||
audience: https://api.example.com # REQUIRED when bearer is enabled
|
||||
clientID: my-api-client-id
|
||||
providerURL: https://issuer.example.com
|
||||
sessionEncryptionKey: <32+-byte secret>
|
||||
callbackURL: /oauth2/callback
|
||||
```
|
||||
|
||||
That is the minimum. Everything else has a secure default.
|
||||
|
||||
## Obtaining bearer tokens from your OIDC provider
|
||||
|
||||
The middleware only **validates** bearer tokens — minting them is the IdP's job. For M2M traffic the canonical mint flow is OAuth 2.0 **`client_credentials`** (RFC 6749 §4.4); some providers require **JWT bearer assertion** (RFC 7523) instead.
|
||||
|
||||
```
|
||||
┌────────────┐ POST /token ┌──────────┐
|
||||
│ client │ ───────────────────────────────►│ IdP │
|
||||
│ (service) │ grant_type=client_credentials │ /token │
|
||||
│ │ client_id=… │ │
|
||||
│ │ client_secret=… (or JWT) │ │
|
||||
│ │ audience=https://api.… ←── critical │
|
||||
│ │ scope=api:read … │
|
||||
│ │ ◄───────────────────────────────│ │
|
||||
│ │ access_token (JWT) │ │
|
||||
└────────────┘ └──────────┘
|
||||
│
|
||||
│ GET /protected
|
||||
│ Authorization: Bearer <access_token>
|
||||
▼
|
||||
Your service (behind Traefik + this plugin)
|
||||
```
|
||||
|
||||
The IdP returns a JWT signed by the same JWKs the middleware already trusts (it discovers them from `providerURL`/.well-known). On the first protected request, the middleware verifies signature + issuer + **audience** + `exp` + identifier claim, then forwards downstream with `X-Forwarded-User` set.
|
||||
|
||||
### Minimal worked example (Auth0-shape)
|
||||
|
||||
```bash
|
||||
# 1. Mint a token
|
||||
curl -s -X POST https://issuer.example.com/oauth/token \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "your-m2m-client-id",
|
||||
"client_secret": "your-m2m-client-secret",
|
||||
"audience": "https://api.example.com",
|
||||
"scope": "api:read api:write"
|
||||
}'
|
||||
# → {"access_token":"eyJhbGciOiJSUzI1NiIs…","token_type":"Bearer","expires_in":86400,…}
|
||||
|
||||
# 2. Use it
|
||||
curl -H 'Authorization: Bearer eyJhbGciOiJSUzI1NiIs…' https://api.example.com/protected
|
||||
```
|
||||
|
||||
The `audience` field in the token request **must match** the `audience` you configured on the middleware. Mismatch → 401 with `Bearer error="invalid_token"`.
|
||||
|
||||
### Per-provider quick reference
|
||||
|
||||
| Provider | Grant | Token endpoint | Audience parameter | Notes |
|
||||
|---|---|---|---|---|
|
||||
| **Auth0** | `client_credentials` | `https://TENANT.auth0.com/oauth/token` | `audience=<your API identifier>` | Register an "API" + "Machine to Machine Application" authorised against that API. Without `audience` you get an opaque /userinfo token, which the bearer path rejects. See `docs/AUTH0_AUDIENCE_GUIDE.md`. |
|
||||
| **Okta** | `client_credentials` | `https://TENANT.okta.com/oauth2/default/v1/token` | Configured in the authorization server; default `aud` is the auth-server URL | Service app must enable the `client_credentials` flow and be granted the requested scopes. |
|
||||
| **Keycloak** | `client_credentials` | `https://kc/realms/REALM/protocol/openid-connect/token` | Configure an "Audience" mapper on a client scope, or use `client_id` as the audience | Client must have `serviceAccountsEnabled: true` plus role mappings. |
|
||||
| **Entra ID / Azure AD** | `client_credentials` (v2.0 endpoint) | `https://login.microsoftonline.com/TENANT/oauth2/v2.0/token` | Pass `scope=<App ID URI>/.default`; `aud` ends up being the API's App ID URI | Requires an App Registration + API permissions + admin consent. **Use the v2.0 endpoint** — v1 issues Microsoft-proprietary access tokens that are opaque to non-Microsoft clients. |
|
||||
| **AWS Cognito** | `client_credentials` | `https://YOUR_DOMAIN.auth.REGION.amazoncognito.com/oauth2/token` | Scopes from a "Resource Server" attached to your User Pool | App client must have `client_credentials` flow enabled. Use HTTP **Basic** auth header for `client_id:client_secret`. |
|
||||
| **GitLab** | `client_credentials` | `https://gitlab.com/oauth/token` | Audience matches the GitLab issuer | Rarely used for protecting external APIs; better suited for GitLab's own resources. |
|
||||
| **Google** | **JWT bearer (RFC 7523)** — *not* `client_credentials` | `https://oauth2.googleapis.com/token` | Signed assertion JWT carries `aud=https://oauth2.googleapis.com/token`; resulting access token is **opaque** unless you specifically request a Google-issued JWT for your API | Google service-account flow is not the best fit for this middleware (opaque tokens are rejected on the bearer path). Run Auth0 / Okta / Keycloak in front, or use ID-token-based flows on the cookie path. |
|
||||
|
||||
### RFC 7523 (JWT bearer assertion) — secretless alternative
|
||||
|
||||
When shared secrets are forbidden (FAPI, internal compliance), swap `client_secret` for a signed JWT assertion:
|
||||
|
||||
```
|
||||
POST /token
|
||||
grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer
|
||||
assertion=<JWT signed by the client's private key>
|
||||
```
|
||||
|
||||
The assertion JWT carries `iss=<client_id>`, `sub=<client_id>`, `aud=<token endpoint>`, `exp`. The IdP verifies the signature against a public key you've pre-registered and returns an access token.
|
||||
|
||||
This middleware already supports JWT assertions on the *middleware → IdP* hop via `clientAuthMethod: private_key_jwt` (see `docs/CONFIGURATION.md`). For the *client → IdP* hop, the same pattern applies — the client signs its own assertion.
|
||||
|
||||
### Operational notes
|
||||
|
||||
- **Token TTL is typically 1–24 hours.** Clients should refresh on `401`, not on a polling timer — saves the IdP.
|
||||
- **Cache and reuse tokens.** The middleware caches verified tokens too, so repeated presentations are cheap. Clients SHOULD reuse a token until ~80 % of `expires_in`.
|
||||
- **JWKS rotation is transparent.** The middleware auto-refreshes its JWKS cache when the IdP rotates keys. Clients don't need to do anything.
|
||||
- **Revocation is generally not per-token** with `client_credentials`. If you need real-time revocation, set `requireTokenIntrospection: true` on the middleware and the IdP is consulted on every cache miss.
|
||||
- **`scope` vs `audience`.** Scope says *what the client may do*; audience says *which service the token is for*. The middleware enforces audience; the backend service should enforce scope.
|
||||
- **Secret hygiene.** Store `client_secret` in a secrets manager (Vault, AWS Secrets Manager, Kubernetes `Secret`). For higher assurance, switch the client to `private_key_jwt` (no shared secret at all).
|
||||
|
||||
### Quickest validation loop
|
||||
|
||||
```bash
|
||||
# 1. Mint
|
||||
TOKEN=$(curl -s -X POST https://issuer.example.com/oauth/token \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"grant_type":"client_credentials","client_id":"…","client_secret":"…","audience":"https://api.example.com"}' \
|
||||
| jq -r .access_token)
|
||||
|
||||
# 2. Inspect claims to confirm aud/iss/exp match the middleware config
|
||||
echo "$TOKEN" | cut -d. -f2 | base64 -d 2>/dev/null | jq
|
||||
|
||||
# 3. Hit the protected route
|
||||
curl -i -H "Authorization: Bearer $TOKEN" https://api.example.com/protected
|
||||
```
|
||||
|
||||
`HTTP/1.1 200` with `X-Forwarded-User` on the backend confirms the loop works end-to-end. `401` with `WWW-Authenticate: Bearer error="invalid_token"` plus a middleware debug log explaining the rejection (audience mismatch, ID token presented, `iat` outside the 24h window, etc.) confirms the hardening is firing as designed.
|
||||
|
||||
## Threat model and design rules
|
||||
|
||||
Bearer authentication has materially different security properties from
|
||||
cookie sessions: no `HttpOnly`/`Secure`/`SameSite` shielding, the token is
|
||||
visible in headers and logs, and it's easier to exfiltrate. The bearer path
|
||||
treats every one of these as a first-class concern.
|
||||
|
||||
| Property | Behaviour | Why |
|
||||
|---|---|---|
|
||||
| Default state | `enableBearerAuth=false` | Bearer is opt-in; existing deployments observe no change. |
|
||||
| Audience | **Mandatory.** Startup fails if `audience` is empty when bearer is enabled. | Eliminates the "token issued for service B accepted by service A" confusion attack. |
|
||||
| Token format | JWT only (3 segments, JOSE-encoded). Opaque tokens are not accepted on the bearer path. | Matches the validation pipeline; opaque tokens require introspection only and bypass JWT-specific defences. |
|
||||
| `alg` allowlist | Hard-pinned asymmetric: `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. Checked **before** any JWKS fetch. | Denies `alg=none` and `alg=HS*` probes; prevents attacker noise from amplifying into JWKS round-trips. |
|
||||
| `kid` hardening | Max 256 bytes; charset `[A-Za-z0-9._\-=]`. Checked **before** JWKS fetch. | Prevents cache-key explosion / pathological-`kid` JWKS amplification. |
|
||||
| Token type | ID tokens are explicitly rejected (`nonce` claim, `typ: at+jwt`, `token_use=id`, scope/aud heuristics — reuses the existing `detectTokenType` helper). | ID tokens are not API credentials; treating them as such is classic token confusion. |
|
||||
| Multi-audience | When `aud` is an array of length > 1, the token must carry `azp == clientID`. | OIDC §2 hardening against tokens minted for one client being replayed by another. |
|
||||
| `iat` upper-age | Rejects tokens older than `maxTokenAgeSeconds` (default 24h). | Bounds clock-manipulation / forever-token abuse, even if `exp` is far in the future. |
|
||||
| Identifier claim | `bearerIdentifierClaim` (default `"sub"`). Resolved value drives `X-Forwarded-User`. | Decoupled from the cookie path's `UserIdentifierClaim` (default `email`) so the M2M flow can never accidentally trust an unverified email. |
|
||||
| Identifier sanitisation | Length cap (`maxIdentifierLength`, default 256). Rejects control chars, Unicode bidi-overrides (U+202A–U+202E, U+2066–U+2069), and the delimiters `, ; =`. | Defence in depth against downstream header injection / log injection / admin-UI spoofing. |
|
||||
| JTI replay marking | Bearer path skips the JTI **Set** (so the same token can be reused until `exp`) but the **Get** stays active. | Allows legitimate bearer reuse without false-positive replay detection; revoked tokens (added to the blacklist by `RevokeToken`) still fail immediately. |
|
||||
| Mixed bearer + cookie | **Cookie wins by default.** Flip to bearer-wins with `bearerOverridesCookie=true`. | Safer against browser/extension/proxy bearer injection scenarios. The cookie is the authoritative authenticator when present. |
|
||||
| `Authorization` strip | `stripAuthorizationHeader=true` by default. | Keeps the raw token out of downstream services and their logs. |
|
||||
| Excluded URLs | `Authorization` is stripped on excluded paths when `enableBearerAuth=true`. | Prevents bearer leakage into public health/metrics endpoint logs and prevents recon via excluded paths. |
|
||||
| Per-IP throttle | After `bearerFailureThreshold` consecutive 401s from one source IP within `bearerFailureWindowSeconds`, further bearer requests from that IP return `429 Too Many Requests` + `Retry-After` for `bearerFailurePenaltySeconds`. | Limits offline-guessing-style attacks and protects the shared rate-limiter / JWKS endpoint. |
|
||||
| Optional introspection | `requireTokenIntrospection=true` calls RFC 7662 introspection on every cache miss. Introspection result is cached briefly. Endpoint failure returns `503` (distinguishes infra outage from credential rejection). | Real-time revocation for high-assurance environments. Adds per-request IdP latency. |
|
||||
| Response shape | `401 Unauthorized` with generic body. `WWW-Authenticate: Bearer error="invalid_token"` per RFC 6750 §3 (toggleable via `bearerEmitWWWAuthenticate`). `403` for roles/groups denial. `429` for throttle. `503` for introspection-endpoint outage. | Auditable from spec to code; reason categories never leak into the response body. |
|
||||
| Logging | Failure reason + identifier hash (SHA-256 truncated to 8 hex chars) logged at debug. Raw tokens are never logged. | Audit trail without secrets-in-logs. |
|
||||
|
||||
## Configuration reference
|
||||
|
||||
| Field | Default | Description |
|
||||
|---|---|---|
|
||||
| `enableBearerAuth` | `false` | Master switch for the bearer path. |
|
||||
| `audience` | (unset) | **Required** when `enableBearerAuth=true`. Reuses the existing global `audience` field. |
|
||||
| `bearerIdentifierClaim` | `"sub"` | JWT claim used as the principal identifier. `"email"` is rejected at startup. |
|
||||
| `stripAuthorizationHeader` | `true` | Remove the `Authorization` header before forwarding to the backend. Disable only when a downstream needs to re-verify the bearer. |
|
||||
| `bearerEmitWWWAuthenticate` | `true` | Include `WWW-Authenticate: Bearer error="..."` on 401 responses (RFC 6750 §3). Disable to reduce recon signal. |
|
||||
| `bearerOverridesCookie` | `false` | Cookie wins when both are present (default). Set `true` for the AWS/GCP/Kubernetes bearer-wins convention. |
|
||||
| `maxTokenAgeSeconds` | `86400` | Upper bound on `iat` claim age (24h). Set `0` to disable the check (not recommended). |
|
||||
| `maxIdentifierLength` | `256` | Length cap for the post-sanitisation identifier. |
|
||||
| `bearerFailureThreshold` | `20` | Consecutive 401s from one IP that trip the throttle. |
|
||||
| `bearerFailureWindowSeconds` | `60` | Rolling window over which 401s are counted. |
|
||||
| `bearerFailurePenaltySeconds` | `60` | Duration of the 429 penalty box after the threshold trips. |
|
||||
| `requireTokenIntrospection` | `false` | Call RFC 7662 introspection on every cache miss. Adds per-request IdP latency. |
|
||||
|
||||
## What the bearer path does NOT do
|
||||
|
||||
- **Human-user / browser flows.** The bearer path is M2M-only in this
|
||||
iteration. Browser SPAs that want to attach a bearer to fetch calls work
|
||||
if your backend treats them as machine clients, but the spec defaults are
|
||||
tuned for service-to-service traffic.
|
||||
- **Opaque access tokens.** Tokens must be JWTs. Introspection is a
|
||||
revocation overlay on top of JWT verification, not a substitute for it.
|
||||
- **`email_verified` enforcement.** The bearer path rejects `email` as the
|
||||
identifier claim at startup precisely because `email_verified` is not
|
||||
enforced in this iteration. Adding human-user bearer support is a
|
||||
follow-up that must include this check.
|
||||
- **mTLS / API keys.** Out of scope. The `principal` abstraction enables
|
||||
adding these later as additional auth methods that produce a principal
|
||||
for the shared `forwardAuthorized` pipeline.
|
||||
- **SSE / WebSocket bypass with bearer.** Bypass paths keep their existing
|
||||
cookie-only behaviour; bearer headers are ignored on those endpoints.
|
||||
Documented limitation; widen by removing the bypass if you need bearer on
|
||||
streaming endpoints.
|
||||
|
||||
## Operational guidance
|
||||
|
||||
- **Always set `strictAudienceValidation: true` when bearer is enabled.**
|
||||
Startup logs a recommendation if you don't.
|
||||
- **Set a tight `maxTokenAgeSeconds`** for environments where tokens are
|
||||
expected to be minted frequently — the default 24h is conservative.
|
||||
- **Enable `requireTokenIntrospection`** if your IdP supports it and
|
||||
revocation latency matters. Bearer-path introspection caches results for
|
||||
a short window per token.
|
||||
- **Monitor 429s.** Sustained 429 traffic indicates either a buggy client
|
||||
loop or an active credential-stuffing attempt. The throttle is your
|
||||
primary signal for both.
|
||||
- **`stripAuthorizationHeader=false` extends the token's blast radius** to
|
||||
every downstream service that sees the request. Treat those services'
|
||||
logs as token stores.
|
||||
- **Bearer reuse is normal.** Don't enable per-token rate limiting; that's
|
||||
what `bearerFailureThreshold` is for (per-IP, not per-token).
|
||||
- **Cookie-wins is the safer default.** Only flip `bearerOverridesCookie`
|
||||
if you control all clients and have audited that none of them present a
|
||||
cookie alongside a bearer they don't intend to authenticate with.
|
||||
|
||||
## Failure response matrix
|
||||
|
||||
| Trigger | Status | Body | `WWW-Authenticate` |
|
||||
|---|---|---|---|
|
||||
| Empty bearer after prefix | 401 | `Unauthorized` | `Bearer error="invalid_request"` |
|
||||
| Token over `MaxLength` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Not a 3-segment JWT | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Disallowed `alg` (e.g. none, HS*) | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Missing / oversized / bad-charset `kid` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Signature / issuer / audience / `exp` failure | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| `iat` older than `maxTokenAgeSeconds` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Multi-audience token without matching `azp` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Detected as ID token | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| JTI blacklisted (revoked) | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Introspection reports `active=false` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Introspection endpoint failure | 503 | `Service Unavailable` | (none) |
|
||||
| Identifier claim missing / empty | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Identifier fails sanitisation | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
|
||||
| Per-IP failure threshold tripped | 429 | `Too Many Requests` | (none); `Retry-After: <bearerFailurePenaltySeconds>` |
|
||||
| Roles / groups not allowed | 403 | `Access denied` | (none) |
|
||||
|
||||
## Known follow-ups (deferred)
|
||||
|
||||
These are documented as future work, not blockers:
|
||||
|
||||
- **Human-user bearer with `email_verified` enforcement.** Requires
|
||||
decoupling the email-claim guard from the startup rejection and adding a
|
||||
per-request `email_verified=true` check.
|
||||
- **Introspection respects `client_assertion`.** The existing introspection
|
||||
helper uses `client_secret_basic` only; operators on `private_key_jwt`
|
||||
will see introspection silently use basic auth.
|
||||
- **Per-route bearer configuration.** Single middleware-wide setting in this
|
||||
iteration.
|
||||
|
||||
## References
|
||||
|
||||
- [PR design spec](superpowers/specs/2026-05-18-bearer-token-auth-design.md) — full design rationale, alternatives considered, and per-section sign-off history.
|
||||
- [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) — Bearer Token Usage.
|
||||
- [RFC 7662](https://www.rfc-editor.org/rfc/rfc7662) — OAuth 2.0 Token Introspection.
|
||||
- [RFC 9068](https://www.rfc-editor.org/rfc/rfc9068) — JWT Profile for OAuth 2.0 Access Tokens.
|
||||
+185
-8
@@ -5,6 +5,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|
||||
## Table of Contents
|
||||
|
||||
- [Required Parameters](#required-parameters)
|
||||
- [Client Authentication](#client-authentication)
|
||||
- [Optional Parameters](#optional-parameters)
|
||||
- [Security Options](#security-options)
|
||||
- [Session Management](#session-management)
|
||||
@@ -22,7 +23,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|
||||
|-----------|------|-------------|---------|
|
||||
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`. Optional when `clientAuthMethod: private_key_jwt`. | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
|
||||
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
|
||||
|
||||
@@ -45,6 +46,129 @@ spec:
|
||||
|
||||
---
|
||||
|
||||
## Client Authentication
|
||||
|
||||
The middleware supports three client authentication methods at the token and
|
||||
revocation endpoints. The default is `client_secret_post` (current behavior);
|
||||
`private_key_jwt` is opt-in and backwards compatible.
|
||||
|
||||
| Method | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `client_secret_post` | yes | `client_id` + `client_secret` in the request body. |
|
||||
| `client_secret_basic` | no | RFC 6749 §2.3.1 — `client_id` + `client_secret` in the `Authorization: Basic` header (form-urlencoded then base64); not in the body. |
|
||||
| `private_key_jwt` | no | RFC 7523 §2.2 — plugin signs a short-lived JWT with a private key and sends it as `client_assertion`. |
|
||||
|
||||
Select via `clientAuthMethod`:
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
```
|
||||
|
||||
### client_secret_post
|
||||
|
||||
Default. The plugin sends `client_id` and `client_secret` as form parameters
|
||||
in the token / revocation request body. No additional configuration required.
|
||||
|
||||
### private_key_jwt
|
||||
|
||||
Asymmetric client authentication per
|
||||
[RFC 7523 §2.2](https://www.rfc-editor.org/rfc/rfc7523). Use this when your
|
||||
IdP enforces short secret TTLs, when policy mandates secretless clients, or
|
||||
when you want to avoid distributing a shared secret to the proxy.
|
||||
|
||||
For each token / revocation request the plugin builds a JWS with:
|
||||
|
||||
- `iss` = `sub` = `clientID`
|
||||
- `aud` = token endpoint URL
|
||||
- `iat` = now, `exp` = now + 60s
|
||||
- `jti` = random hex per request
|
||||
- `kid` header = `clientAssertionKeyID`
|
||||
|
||||
**Required fields:**
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `clientAuthMethod` | string | `client_secret_post` | Set to `private_key_jwt`. |
|
||||
| `clientAssertionPrivateKey` | string | none | Inline PEM private key. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8, PKCS#1, and SEC1 formats accepted. |
|
||||
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk. Mutually exclusive with `clientAssertionPrivateKey`. |
|
||||
| `clientAssertionKeyID` | string | none | `kid` header inserted in the JWS. Must match the public key registered with the IdP. |
|
||||
| `clientAssertionAlg` | string | `RS256` | One of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. |
|
||||
|
||||
When `clientAuthMethod: private_key_jwt`, `clientSecret` is optional.
|
||||
|
||||
**Example — inline PEM:**
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://idp.example.com
|
||||
clientID: my-client-id
|
||||
sessionEncryptionKey: your-32-byte-encryption-key-here
|
||||
callbackURL: /oauth2/callback
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyID: key-2026-01
|
||||
clientAssertionAlg: RS256
|
||||
clientAssertionPrivateKey: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKj
|
||||
MZj4ev7QnMa1mYV3Kx1jRkH5YwXQ7N2J2j8K5pP6h0oZmXq1yQv4r8wZb3sH9D2k
|
||||
... (truncated) ...
|
||||
-----END PRIVATE KEY-----
|
||||
```
|
||||
|
||||
**Example — key on disk:**
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
|
||||
clientAssertionKeyID: key-2026-01
|
||||
clientAssertionAlg: RS256
|
||||
```
|
||||
|
||||
**Generating an RS256 key with OpenSSL:**
|
||||
|
||||
```bash
|
||||
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 \
|
||||
-out client-key.pem
|
||||
openssl rsa -in client-key.pem -pubout -out client-pub.pem
|
||||
```
|
||||
|
||||
Register `client-pub.pem` (or its JWK form) with your IdP under the same
|
||||
`kid` you set in `clientAssertionKeyID`.
|
||||
|
||||
**Notes:**
|
||||
|
||||
- The private key is parsed once at plugin startup. Key rotation requires a
|
||||
Traefik reload.
|
||||
- Assertion lifetime is fixed at 60 seconds.
|
||||
- A fresh random `jti` is generated per request.
|
||||
- The `aud` claim is the token endpoint URL (from discovery).
|
||||
- Tracking issue:
|
||||
[#135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
|
||||
|
||||
### client_secret_basic
|
||||
|
||||
Per [RFC 6749 §2.3.1][rfc6749-2-3-1], the plugin sends the client credentials
|
||||
in an `Authorization: Basic` header instead of the body. Both halves
|
||||
(`client_id`, `client_secret`) are form-urlencoded individually, joined with
|
||||
a colon, then base64-encoded. Use this when your IdP requires Basic auth at
|
||||
the token endpoint and rejects credentials in the body.
|
||||
|
||||
```yaml
|
||||
clientAuthMethod: client_secret_basic
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
```
|
||||
|
||||
[rfc6749-2-3-1]: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
|
||||
|
||||
---
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
@@ -52,23 +176,55 @@ spec:
|
||||
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
|
||||
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
|
||||
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
|
||||
| `clientAuthMethod` | string | `client_secret_post` | Client authentication method at token/revocation endpoints. One of `client_secret_post`, `client_secret_basic`, `private_key_jwt`. See [Client Authentication](#client-authentication). |
|
||||
| `clientAssertionPrivateKey` | string | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8 / PKCS#1 / SEC1. |
|
||||
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk for `private_key_jwt`. Mutually exclusive with `clientAssertionPrivateKey`. |
|
||||
| `clientAssertionKeyID` | string | none | `kid` header for `private_key_jwt` assertions. Required when `clientAuthMethod: private_key_jwt`. |
|
||||
| `clientAssertionAlg` | string | `RS256` | Signing algorithm for `private_key_jwt`. One of `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
|
||||
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
|
||||
the correct default behind any TLS-terminating load balancer (AWS ALB, Google
|
||||
Cloud LB, Azure App Gateway) — `X-Forwarded-Proto` cannot be trusted (ALB may
|
||||
overwrite it).
|
||||
|
||||
```yaml
|
||||
forceHTTPS: true # Required for correct redirect URIs
|
||||
```
|
||||
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
|
||||
dev). Otherwise leave it at default.
|
||||
|
||||
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
|
||||
### Streaming Endpoints (SSE and WebSocket)
|
||||
|
||||
The middleware automatically bypasses the OIDC redirect for two request kinds
|
||||
that browsers cannot follow a 302 on:
|
||||
|
||||
| Bypass | Triggered by |
|
||||
|--------|--------------|
|
||||
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
|
||||
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
|
||||
|
||||
These requests do **not** require any explicit configuration — they are
|
||||
handled implicitly. However, the bypass is **not** unauthenticated:
|
||||
|
||||
- A valid, encrypted session cookie is required. Requests without one are
|
||||
rejected (the connection cannot proceed to the backend).
|
||||
- The session cookie is sealed with `sessionEncryptionKey`, so the
|
||||
`authenticated` flag cannot be forged.
|
||||
- Validation is cookie-only — no JWK fetch / signature verification — so
|
||||
streaming endpoints keep working when the OIDC provider is briefly
|
||||
unavailable.
|
||||
- The user identifier from the session is forwarded as `X-Forwarded-User`
|
||||
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
|
||||
|
||||
For browser clients, the user must complete the normal OIDC flow on a
|
||||
regular HTTP page first; the resulting session cookie is then reused on the
|
||||
SSE / WebSocket connection.
|
||||
|
||||
---
|
||||
|
||||
@@ -105,6 +261,26 @@ strictAudienceValidation: true
|
||||
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
|
||||
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
|
||||
|
||||
### Bearer-token (M2M) authentication
|
||||
|
||||
Opt-in path that accepts `Authorization: Bearer <jwt>` instead of the cookie
|
||||
session flow. M2M-only, default off, audience-mandatory. See
|
||||
[docs/BEARER_AUTH.md](BEARER_AUTH.md) for the threat model and operational
|
||||
guidance.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `enableBearerAuth` | bool | `false` | Master switch. Startup fails if true with empty `audience` or with `bearerIdentifierClaim=email`. |
|
||||
| `bearerIdentifierClaim` | string | `"sub"` | JWT claim used as the principal identifier. `"email"` is rejected at startup. |
|
||||
| `stripAuthorizationHeader` | bool | `true` | Strip `Authorization` from forwarded requests after successful bearer auth. |
|
||||
| `bearerEmitWWWAuthenticate` | bool | `true` | Emit RFC 6750 `WWW-Authenticate: Bearer error="..."` hints on 401. |
|
||||
| `bearerOverridesCookie` | bool | `false` | Cookie wins when both bearer and cookie are present (default). Set true for bearer-wins. |
|
||||
| `maxTokenAgeSeconds` | int64 | `86400` | Upper bound on `iat` claim age (24h). 0 disables the check. |
|
||||
| `maxIdentifierLength` | int | `256` | Length cap on the sanitised principal identifier. |
|
||||
| `bearerFailureThreshold` | int | `20` | Consecutive 401s from one source IP that trip the throttle. |
|
||||
| `bearerFailureWindowSeconds` | int | `60` | Rolling window for counting 401s. |
|
||||
| `bearerFailurePenaltySeconds` | int | `60` | 429 + `Retry-After` duration after the threshold trips. |
|
||||
|
||||
---
|
||||
|
||||
## Session Management
|
||||
@@ -113,6 +289,7 @@ strictAudienceValidation: true
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
# Dynamic Client Registration (RFC 7591)
|
||||
|
||||
The middleware can register itself with an OIDC provider at startup instead of
|
||||
using a pre-provisioned `clientID` / `clientSecret`. Useful for multi-tenant
|
||||
deployments, self-service integrations, and ephemeral environments.
|
||||
|
||||
## How it works
|
||||
|
||||
1. Middleware reads `registration_endpoint` from `.well-known/openid-configuration`.
|
||||
2. If `clientID` is empty, it `POST`s `clientMetadata` to the registration endpoint.
|
||||
3. Returned `client_id` / `client_secret` are cached, optionally persisted.
|
||||
4. Subsequent requests use the registered credentials.
|
||||
|
||||
For multi-replica deployments, set `storageBackend: redis` so all replicas
|
||||
share one client and avoid registration races.
|
||||
|
||||
## Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-dcr
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-oidc-provider.com
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: redis # file | redis | auto
|
||||
initialAccessToken: "" # optional, for protected endpoints
|
||||
registrationEndpoint: "" # optional, override discovery
|
||||
credentialsFile: /tmp/oidc-client-credentials.json
|
||||
redisKeyPrefix: "dcr:creds:"
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- https://app.example.com/oauth2/callback
|
||||
client_name: My Application
|
||||
application_type: web
|
||||
grant_types: [authorization_code, refresh_token]
|
||||
response_types: [code]
|
||||
token_endpoint_auth_method: client_secret_basic
|
||||
contacts: [admin@example.com]
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `enabled` | `false` | Enable DCR. |
|
||||
| `persistCredentials` | `false` | Save returned credentials for reuse across restarts. |
|
||||
| `storageBackend` | `auto` | `file`, `redis`, or `auto` (Redis if available, else file). |
|
||||
| `credentialsFile` | `/tmp/oidc-client-credentials.json` | Path for file-backed storage. Mode `0600`. |
|
||||
| `redisKeyPrefix` | (none — set explicitly) | Key prefix for Redis-backed storage. The code does not inject a default; if unset, keys have no prefix. `dcr:creds:` is a sensible convention. |
|
||||
| `registrationEndpoint` | discovered | Override the discovered endpoint. |
|
||||
| `initialAccessToken` | none | Bearer token for protected registration endpoints. |
|
||||
| `clientMetadata.redirect_uris` | required | Callback URIs for the OAuth flow. |
|
||||
| `clientMetadata.client_name` | none | Human-readable client name. |
|
||||
| `clientMetadata.application_type` | `web` | `web` or `native`. |
|
||||
| `clientMetadata.grant_types` | `[authorization_code, refresh_token]` | OAuth grant types. |
|
||||
| `clientMetadata.response_types` | `[code]` | OAuth response types. |
|
||||
| `clientMetadata.token_endpoint_auth_method` | `client_secret_basic` | `client_secret_basic`, `client_secret_post`, or `none`. |
|
||||
| `clientMetadata.scope` | none | Space-separated scopes. |
|
||||
| `clientMetadata.contacts` | none | Admin email addresses. |
|
||||
| `clientMetadata.logo_uri` | none | Logo URL for consent screens. |
|
||||
| `clientMetadata.client_uri` | none | Client homepage URL. |
|
||||
| `clientMetadata.policy_uri` | none | Privacy policy URL. |
|
||||
| `clientMetadata.tos_uri` | none | Terms of service URL. |
|
||||
|
||||
## Provider support
|
||||
|
||||
The middleware does not gate DCR by provider — if the provider exposes a
|
||||
`registration_endpoint` in its discovery document (or you set
|
||||
`registrationEndpoint` explicitly), DCR will attempt registration. The table
|
||||
below is informational guidance based on each provider's published support.
|
||||
|
||||
| Provider | DCR | Notes |
|
||||
|----------|-----|-------|
|
||||
| Keycloak | Yes | Enable in realm settings. |
|
||||
| Auth0 | Yes | Requires Management API token. |
|
||||
| Okta | Yes | Enable Dynamic Client Registration in admin console. |
|
||||
| Azure AD | Limited | Use App Registration API instead. |
|
||||
| Google | No | Manual registration required. |
|
||||
| AWS Cognito | No | Manual registration required. |
|
||||
|
||||
## Security notes
|
||||
|
||||
- Registration endpoints must be HTTPS (loopback excepted for local dev).
|
||||
- Use `initialAccessToken` in production to gate registration.
|
||||
- File-backed credentials use `0600`; protect the mount path.
|
||||
- The plugin marks credentials invalid when within ~5 min of `client_secret_expires_at` but does **not** automatically re-register. If your provider sets a non-zero expiry, schedule manual rotation (delete the credentials file or Redis entry, restart) before that time.
|
||||
+20
-99
@@ -16,9 +16,8 @@ Guide for local development, testing, and contributing to the Traefik OIDC middl
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Go 1.23+** for plugin compilation
|
||||
- **Docker & Docker Compose** for local testing
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.)
|
||||
- **Go 1.24+** (matches `go.mod`; CI runs Go 1.24.11)
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.) for any end-to-end test against a real provider
|
||||
|
||||
### Required Development Tools
|
||||
|
||||
@@ -40,110 +39,32 @@ go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
## Local Development Setup
|
||||
|
||||
### Docker Compose Environment
|
||||
|
||||
The repository includes a Docker Compose setup for testing the plugin locally.
|
||||
|
||||
#### 1. Host Configuration
|
||||
|
||||
Add to `/etc/hosts`:
|
||||
### Build and unit tests
|
||||
|
||||
```bash
|
||||
127.0.0.1 hello.localhost
|
||||
127.0.0.1 traefik.localhost
|
||||
go mod tidy
|
||||
go build ./...
|
||||
go test ./... -short # fast loop, < 30 s
|
||||
go test -race -timeout=15m ./...
|
||||
```
|
||||
|
||||
#### 2. Plugin Configuration
|
||||
### Sample plugin configurations
|
||||
|
||||
The plugin is loaded using Traefik's **local plugins mode**:
|
||||
Working middleware/Traefik configs live in [`examples/`](../examples/):
|
||||
|
||||
- Plugin source: Parent directory (`../`)
|
||||
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
|
||||
- Configuration: `experimental.localPlugins` in `traefik.yml`
|
||||
- `complete-traefik-config.yaml` — full middleware example
|
||||
- `redis-config.yaml` — Redis cache configuration
|
||||
|
||||
#### 3. OIDC Provider Setup
|
||||
To run the plugin against a real Traefik instance, drop the project on disk
|
||||
and load it via `experimental.localPlugins` in your Traefik static config —
|
||||
see the [README install section](../README.md#install).
|
||||
|
||||
Edit `docker/dynamic.yml` with your provider details:
|
||||
### Integration tests
|
||||
|
||||
**Google:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
logoutURL: "/oauth2/logout"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
**Azure AD:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
#### 4. Start Environment
|
||||
Integration tests live in `integration/`. Run them explicitly:
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 5. Test Plugin
|
||||
|
||||
- **Protected App**: http://hello.localhost (redirects to OIDC)
|
||||
- **Traefik Dashboard**: http://traefik.localhost:8080
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. **Edit plugin code** in the project root
|
||||
2. **Build and test** (optional syntax check):
|
||||
```bash
|
||||
go mod tidy
|
||||
go build .
|
||||
go test ./...
|
||||
```
|
||||
3. **Restart Traefik** to reload plugin:
|
||||
```bash
|
||||
docker-compose restart traefik
|
||||
```
|
||||
4. **Test changes** at http://hello.localhost
|
||||
|
||||
### Debugging
|
||||
|
||||
**View plugin logs:**
|
||||
```bash
|
||||
docker-compose logs -f traefik | grep traefikoidc
|
||||
```
|
||||
|
||||
**Check plugin loading:**
|
||||
```bash
|
||||
docker-compose logs traefik | grep -i plugin
|
||||
```
|
||||
|
||||
**Verify plugin directory:**
|
||||
```bash
|
||||
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
|
||||
go test ./integration/... -run Integration -v
|
||||
```
|
||||
|
||||
---
|
||||
@@ -299,7 +220,7 @@ The repository uses GitHub Actions for comprehensive validation with 20+ paralle
|
||||
|
||||
#### Testing (9 suites)
|
||||
- Race Detector
|
||||
- Coverage (75% threshold)
|
||||
- Coverage (70% threshold, enforced in `pr.yaml`)
|
||||
- Memory Leaks
|
||||
- Integration Tests
|
||||
- Regression Tests
|
||||
@@ -323,13 +244,13 @@ Tests run in parallel for:
|
||||
#### Performance & Build (3 checks)
|
||||
- Benchmarks
|
||||
- Multi-platform Build (linux/darwin x amd64/arm64)
|
||||
- Go Version Compatibility (Go 1.23 & 1.24)
|
||||
- Go Version Compatibility (currently Go 1.24.11 in CI)
|
||||
|
||||
### Quality Gates
|
||||
|
||||
All PRs must pass:
|
||||
- All parallel checks
|
||||
- 75% test coverage minimum
|
||||
- 70% test coverage minimum
|
||||
- Zero security vulnerabilities
|
||||
- No race conditions
|
||||
- No memory leaks
|
||||
|
||||
+3
-3
@@ -23,10 +23,10 @@ Configuration reference for each supported OIDC provider.
|
||||
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|
||||
|----------|-------------|----------------|----------------|-----------|
|
||||
| Google | Full | Yes | `accounts.google.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com`, `sts.windows.net` | Yes |
|
||||
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com` | Yes |
|
||||
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com`, `*.oktapreview.com`, `*.okta-emea.com` | Yes |
|
||||
| Keycloak | Full | Yes | host containing `keycloak`, or `/realms/` in path (matches both `/auth/realms/` legacy and `/realms/` modern) | Yes |
|
||||
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
|
||||
| GitLab | Full | Yes | `gitlab.com` | Yes |
|
||||
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
|
||||
|
||||
+14
-6
@@ -109,11 +109,11 @@ redis:
|
||||
| `writeTimeout` | int | `3` | Write timeout (seconds) |
|
||||
| `enableTLS` | bool | `false` | Enable TLS for connections |
|
||||
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
|
||||
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
|
||||
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
|
||||
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
|
||||
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
|
||||
| `enableCircuitBreaker` | bool | `false` | Wrap the Redis backend with a circuit breaker. **Recommended `true` in production.** |
|
||||
| `circuitBreakerThreshold` | int | `5` | Consecutive failures before the circuit opens (only when `enableCircuitBreaker: true`). |
|
||||
| `circuitBreakerTimeout` | int | `60` | Seconds the circuit stays open before allowing a probe (only when `enableCircuitBreaker: true`). |
|
||||
| `enableHealthCheck` | bool | `false` | Wrap the Redis backend with periodic health checks. **Recommended `true` in production.** |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval in seconds (only when `enableHealthCheck: true`). |
|
||||
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
|
||||
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
|
||||
|
||||
@@ -134,13 +134,21 @@ REDIS_READ_TIMEOUT=3
|
||||
REDIS_WRITE_TIMEOUT=3
|
||||
REDIS_ENABLE_TLS=false
|
||||
REDIS_TLS_SKIP_VERIFY=false
|
||||
REDIS_HYBRID_L1_SIZE=500
|
||||
REDIS_HYBRID_L1_MEMORY_MB=10
|
||||
```
|
||||
|
||||
> Resilience fields (`enableCircuitBreaker`, `enableHealthCheck`,
|
||||
> `circuitBreakerThreshold`, `circuitBreakerTimeout`, `healthCheckInterval`)
|
||||
> have no environment variable fallback — set them in plugin configuration.
|
||||
|
||||
Invalid `cacheMode` values are rejected at plugin startup.
|
||||
|
||||
---
|
||||
|
||||
## Cache Modes
|
||||
|
||||
### Memory Mode (Default without Redis)
|
||||
### Memory Mode (used when Redis is disabled)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
|
||||
+2
-2
@@ -6,8 +6,8 @@ Comprehensive testing infrastructure for traefikoidc.
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Test files | 99 |
|
||||
| Lines of test code | ~65,500 |
|
||||
| Test files | 110 |
|
||||
| Lines of test code | ~72,000 |
|
||||
| Code coverage | 71.0% |
|
||||
| Race conditions | None (all pass with `-race`) |
|
||||
|
||||
|
||||
+46
-3
@@ -642,7 +642,7 @@ spec:
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientSecret</code></td>
|
||||
<td class="py-2 px-3">OAuth 2.0 client secret</td>
|
||||
<td class="py-2 px-3">OAuth 2.0 client secret. Only required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is unset or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">sessionEncryptionKey</code></td>
|
||||
@@ -718,6 +718,11 @@ spec:
|
||||
<td class="py-2 px-3">86400</td>
|
||||
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
|
||||
<td class="py-2 px-3">21600</td>
|
||||
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
|
||||
<td class="py-2 px-3">_oidc_raczylo_</td>
|
||||
@@ -748,15 +753,48 @@ spec:
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Require RFC 7662 introspection for opaque tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">disableReplayDetection</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Disable JTI replay detection (for multi-replica without Redis)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code></td>
|
||||
<td class="py-2 px-3">client_secret_post</td>
|
||||
<td class="py-2 px-3">Selects how the plugin authenticates to the token endpoint. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">Inline PEM private key used to sign client assertions for <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyPath</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">Path to a PEM private key file. Alternative to <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code>.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyID</code></td>
|
||||
<td class="py-2 px-3">none</td>
|
||||
<td class="py-2 px-3">JWS <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">kid</code> header value. Required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionAlg</code></td>
|
||||
<td class="py-2 px-3">RS256</td>
|
||||
<td class="py-2 px-3">Signing algorithm for the client assertion. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES512</code>.</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Private Key JWT (RFC 7523)</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Use this when your IdP (Entra ID, Okta, Auth0, Keycloak) pressures short-lived secrets, or when policy mandates secretless service-to-service authentication. The plugin signs a 60-second assertion with the configured private key and sends it as <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_assertion</code> instead of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret</code>. Public-key registration on the IdP replaces shared-secret rotation. See <a href="https://www.rfc-editor.org/rfc/rfc7523" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">RFC 7523</a> and <a href="https://github.com/lukaszraczylo/traefikoidc/issues/135" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">issue #135</a>.</p>
|
||||
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>clientAuthMethod: private_key_jwt
|
||||
clientAssertionKeyPath: /etc/traefik/oidc-client.pem
|
||||
clientAssertionKeyID: my-client-key-2026
|
||||
# clientSecret no longer required</code></pre>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Google Workspace with Domain Restriction</h3>
|
||||
|
||||
@@ -858,7 +896,12 @@ spec:
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -0,0 +1,459 @@
|
||||
# Bearer Token Authentication — Design Spec
|
||||
|
||||
- **Date**: 2026-05-18
|
||||
- **Status**: Design — pending implementation plan
|
||||
- **Supersedes**: PR #93 (broken implementation; recommended to close in favour of this design)
|
||||
|
||||
## 1. Summary
|
||||
|
||||
Add an opt-in path that lets API clients (machine-to-machine) authenticate by presenting a signed access token in the `Authorization: Bearer <token>` header, bypassing the cookie-based OIDC redirect flow. Identity, roles, and authorization checks remain consistent with the existing cookie path; the only thing that changes is how the principal is established for that single request.
|
||||
|
||||
The feature is implemented by extracting a shared `forwardAuthorized` pipeline from the existing `processAuthorizedRequest`, introducing a `principal` value type, and adding a small bearer-specific entrypoint that builds a principal directly from a verified JWT — without synthesising a fake `SessionData`.
|
||||
|
||||
## 2. Motivation
|
||||
|
||||
PR #93 attempted this feature by building an in-memory `SessionData` from JWT claims and reusing `processAuthorizedRequest`. The approach has three latent defects:
|
||||
|
||||
1. The synthetic session omits `mainSession.Values["user_identifier"]`. `processAuthorizedRequest` reads it via `GetUserIdentifier()`; when empty it bails to `defaultInitiateAuthentication` and issues an OIDC redirect. The feature is non-functional in practice despite the unit test passing.
|
||||
2. `verifyToken` accepts both ID tokens (audience match against `clientID`) and access tokens. ID tokens are not API credentials; treating them as such is a classic token-confusion vector.
|
||||
3. `verifyToken` adds JTI to the replay blacklist on first verify. Once the verified-token cache evicts, subsequent reuse of the same bearer token triggers a false-positive replay rejection.
|
||||
|
||||
Rather than patch a synthetic-session approach that will keep generating bugs as `SessionData` evolves, this spec replaces it with a cleaner abstraction where session lifecycle and post-auth header injection live in separate units.
|
||||
|
||||
## 3. Goals
|
||||
|
||||
- Accept `Authorization: Bearer <jwt>` from M2M clients, validate the token, and forward the request downstream with identity headers populated.
|
||||
- Enforce the same `allowedRolesAndGroups` policy as the cookie path.
|
||||
- Default-off; safe defaults when enabled (audience required, ID tokens rejected, identifier sanitised).
|
||||
- No behavioural change to the cookie path. Existing tests must continue to pass without modification.
|
||||
|
||||
## 4. Non-Goals
|
||||
|
||||
- Human-user / browser flows. Bearer is M2M-only in this iteration.
|
||||
- Pure opaque access tokens on the bearer path. Tokens must be JWTs; introspection (RFC 7662) is supported *on top of* JWT verification for revocation state, not as a substitute for it.
|
||||
- mTLS, API keys, or any other auth method. The `principal` abstraction enables them later, but they are not delivered here.
|
||||
- Per-route bearer configuration. Single middleware-wide setting.
|
||||
|
||||
## 5. Decided Requirements
|
||||
|
||||
| Topic | Decision |
|
||||
|---|---|
|
||||
| Consumer type | Machine-to-machine (M2M) only |
|
||||
| Token format | JWT only (signature, issuer, audience, exp) |
|
||||
| Audience | Mandatory when feature enabled; startup fails if `Audience == ""` |
|
||||
| Token type | Access tokens only; ID tokens explicitly rejected |
|
||||
| Revocation | JWT-only verification by default; introspection (RFC 7662) opt-in via existing `RequireTokenIntrospection` |
|
||||
| Identity claim | New `BearerIdentifierClaim` config (string, default `"sub"`). Bearer path reads this claim exclusively; does NOT use `UserIdentifierClaim` (which defaults to `"email"` and drives the cookie path). Resolved value must be a non-empty string. `sub` is mandatory per `jwt.go:416` regardless, so even with a different `BearerIdentifierClaim` the token must still carry a valid `sub`. Decoupling avoids the M2M-vs-human-user identity-claim conflict and the email-spoofing footgun. |
|
||||
| Identifier sanitisation | Reject value containing any `unicode.IsControl` char, any Unicode bidi-override (U+202A–U+202E, U+2066–U+2069), leading/trailing whitespace, commas, semicolons, equals signs. Max length 256 bytes. |
|
||||
| Token classifier | **Reuse existing `detectTokenType(jwt, token)` at `token_manager.go:187-303`** which already handles `nonce`, `typ: at+jwt`, `token_use`, `scope`, and aud-vs-clientID priority. Bearer path rejects any token where `detectTokenType == true` (ID token). Do not invent a parallel classifier. |
|
||||
| Algorithm pinning | Hard-pin `alg ∈ {RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512}`, enforced **before** JWKS lookup on the bearer path. Prevents wasted JWKS fetches for `alg=none`/HS attacker probes. |
|
||||
| `kid` hardening | `kid` ≤ 256 bytes, charset `[A-Za-z0-9._\-=]`. Reject before JWKS lookup. |
|
||||
| Token age | Bearer path enforces `now - iat <= MaxTokenAgeSeconds` (default 86400 / 24h, configurable). Cookie path unchanged. |
|
||||
| Multi-audience policy | If `aud` is an array (length > 1), require `azp` claim to be present and equal to `clientID`. Single-string `aud` unaffected. |
|
||||
| Mixed bearer + cookie precedence | **Cookie wins by default** when both are presented (safer for browser scenarios). Operator opt-in: `BearerOverridesCookie=true` to flip. Either way, a warning is logged on the request. |
|
||||
| Bearer + excluded URL | `Authorization` header is **stripped** before forwarding when the request hits an excluded URL. Prevents bearer leaking into public endpoints' downstream logs and prevents recon via excluded paths. |
|
||||
| Per-source bearer 401 throttle | New sharded cache `failedBearerAttempts` keyed by client IP. After N (default 20) consecutive 401s from one IP within 1 minute, reject further bearer requests from that IP with 429 for 60s. Applied BEFORE `verifyToken` to deny JWKS amplification. |
|
||||
| `Authorization` header passthrough | New `StripAuthorizationHeader` config, default `true` |
|
||||
| Roles/groups gating | Same `allowedRolesAndGroups` rules as cookie path |
|
||||
| Default state | `EnableBearerAuth` = `false` |
|
||||
| JTI replay marking | Suppressed on bearer path; cookie path unchanged |
|
||||
| Failure response shape | 401 with generic body; `WWW-Authenticate: Bearer error="invalid_token"` per RFC 6750 |
|
||||
| Introspection endpoint outage | 503 (distinguishes infra outage from token rejection) |
|
||||
| Mixed bearer + cookie | Bearer wins; cookie ignored on that request |
|
||||
| SSE/WS bypass + bearer | Bypass paths keep cookie-only check; bearer header ignored on SSE/WS |
|
||||
|
||||
## 6. Architecture
|
||||
|
||||
```
|
||||
┌──────────────────┐
|
||||
HTTP req ──► │ ServeHTTP │ (existing entry; adds bearer detection)
|
||||
└─────────┬────────┘
|
||||
┌───────────┴────────────┐
|
||||
▼ ▼
|
||||
cookie / session bearer (Authorization: Bearer …)
|
||||
│ │
|
||||
▼ ▼
|
||||
┌────────────────┐ ┌────────────────────┐
|
||||
│ buildPrincipal │ │ buildPrincipal │
|
||||
│ FromSession() │ │ FromBearerToken() │
|
||||
└────────┬───────┘ └─────────┬──────────┘
|
||||
│ produces *principal │
|
||||
└──────────────┬───────────┘
|
||||
▼
|
||||
┌────────────────────────────┐
|
||||
│ forwardAuthorized(rw,req,p)│ (shared pipeline)
|
||||
│ • roles/groups gate │
|
||||
│ • header injection │
|
||||
│ • header templates │
|
||||
│ • security headers │
|
||||
│ • cookie stripping │
|
||||
│ • next.ServeHTTP │
|
||||
└────────────────────────────┘
|
||||
```
|
||||
|
||||
**Invariant**: `forwardAuthorized` never touches session storage. Session-specific concerns (Save, IsDirty, backchannel-logout invalidation) stay inside `processAuthorizedRequest` around the call to `forwardAuthorized`.
|
||||
|
||||
**Feature gate**: when `EnableBearerAuth == false`, the bearer-detection check in `ServeHTTP` is a no-op. Existing deployments observe byte-identical behaviour.
|
||||
|
||||
## 7. Components
|
||||
|
||||
### 7.1 `principal` type (new file `principal.go`)
|
||||
|
||||
```go
|
||||
type principalSource int
|
||||
|
||||
const (
|
||||
sourceSession principalSource = iota
|
||||
sourceBearer
|
||||
)
|
||||
|
||||
type principal struct {
|
||||
Identifier string // drives X-Forwarded-User
|
||||
Email string // optional, "" for M2M
|
||||
Subject string // sub claim
|
||||
ClientID string // azp / client_id, M2M caller
|
||||
Claims map[string]interface{} // raw claims for templates / groups
|
||||
AccessToken string // for X-Auth-Request-Token (gated by minimalHeaders)
|
||||
IDToken string // "" on bearer path
|
||||
RefreshToken string // "" on bearer path
|
||||
Source principalSource
|
||||
}
|
||||
```
|
||||
|
||||
Pure data. No methods that mutate it. No I/O. No manager pointer.
|
||||
|
||||
### 7.2 `buildPrincipalFromSession(*SessionData) *principal` (new in `principal.go`)
|
||||
|
||||
Read-only adapter over existing `SessionData` getters: `GetUserIdentifier`, `GetEmail`, `GetAccessToken`, `GetIDToken`, `GetRefreshToken`, cached claims via `GetIDTokenClaims`. Does not write back to the session. This is the only function that still knows about `SessionData`.
|
||||
|
||||
### 7.3 `buildPrincipalFromBearerToken(token string) (*principal, error)` (new in `bearer_auth.go`)
|
||||
|
||||
1. **Length / format guards**: `len(token) <= AccessTokenConfig.MaxLength`, exactly two dots, non-empty after trim.
|
||||
2. **Parse header for early alg/kid pinning** (without trusting payload): decode JOSE header; reject if `alg` ∉ asymmetric allowlist; reject if `kid` missing, > 256 bytes, or contains chars outside `[A-Za-z0-9._\-=]`. This happens **before** JWKS lookup so attacker noise doesn't amplify into JWKS fetches.
|
||||
3. **Per-IP 401 throttle check**: if this IP is in the `failedBearerAttempts` penalty box, return 429 immediately.
|
||||
4. `t.verifyToken(token, verifyOpts{skipReplayMarking: true})` — reuses signature, issuer, audience, expiration, JTI Get (replay detection). The `skipReplayMarking` flag gates ONLY the JTI Set at `token_manager.go:108-143`; the JTI Get at `token_manager.go:44-47, 80-89` remains active so revoked tokens (via `RevokeToken` adding to blacklist) are still rejected.
|
||||
5. **Re-parse claims** (`parseJWT(token)` is cheap and already done internally; reuse via a single decode if practical).
|
||||
6. **Token-type guard**: call existing `detectTokenType(jwt, token)` (`token_manager.go:187-303`). Reject when it returns `true` (ID token). Belt-and-braces: also reject if `claims["nonce"]` is a non-empty string or `claims["token_use"] == "id"`.
|
||||
7. **Multi-audience hardening**: if `claims["aud"]` is a `[]interface{}` with length > 1, require `claims["azp"]` to be a non-empty string equal to `t.clientID`; reject otherwise.
|
||||
8. **`iat` upper-age bound**: reject when `time.Now().Unix() - int64(claims["iat"].(float64)) > MaxTokenAgeSeconds` (default 86400).
|
||||
9. **Optional introspection**: if `requireTokenIntrospection` is set, call `introspectToken`; reject if `active == false` (401); surface 503 on transport failure. Bearer-path introspection cache TTL is capped at 60s (not 5min) to keep the "real-time revocation" promise close to true.
|
||||
10. **Identifier resolution**: read `t.bearerIdentifierClaim` (defaults to `"sub"`); do NOT use `t.userIdentifierClaim` (cookie path's setting, default `email`). The bearer path does NOT fall back to other claims because `jwt.Verify` already enforces non-empty `sub` (`jwt.go:416-419`). Empty/missing identifier → 401.
|
||||
11. **Identifier sanitisation**: trim, then reject if length > 256 OR contains any of: `unicode.IsControl`, bidi-override (U+202A–U+202E, U+2066–U+2069), `,`, `;`, `=`.
|
||||
12. Return `&principal{ Source: sourceBearer, … }`.
|
||||
|
||||
On any failure path: increment the per-IP `failedBearerAttempts` counter; return the appropriate HTTP status (401 / 403 / 429 / 503) without revealing the failure reason in the response body. Reason is logged at debug only, with the identifier (if resolved) hashed via SHA-256 truncated to 8 hex chars.
|
||||
|
||||
### 7.4 `forwardAuthorized(rw, req, *principal)` (new in `middleware.go`, extracted)
|
||||
|
||||
The shared post-auth pipeline. Lifted verbatim from the existing `processAuthorizedRequest`:
|
||||
|
||||
1. Roles/groups extraction via existing `extractGroupsAndRolesFromClaims`.
|
||||
2. `allowedRolesAndGroups` gate (existing logic).
|
||||
3. Inject `X-Forwarded-User`, `X-User-Groups`, `X-User-Roles`.
|
||||
4. Inject `X-Auth-Request-*` (gated by `minimalHeaders`).
|
||||
5. Header templates.
|
||||
6. Security headers.
|
||||
7. Cookie strip when `stripAuthCookies`.
|
||||
8. **New**: `Authorization` header strip when `stripAuthorizationHeader` AND `principal.Source == sourceBearer`.
|
||||
9. `t.next.ServeHTTP(rw, req)`.
|
||||
|
||||
Does not call `Save`, does not check `IsDirty`. Session persistence stays with the cookie-path caller.
|
||||
|
||||
### 7.5 `handleBearerRequest(rw, req)` (new in `bearer_auth.go`)
|
||||
|
||||
```
|
||||
1. Detect "Authorization: Bearer <token>" (case-insensitive prefix).
|
||||
2. token = TrimSpace(authHeader[7:]); reject empty.
|
||||
3. p, err := buildPrincipalFromBearerToken(token).
|
||||
On err → 401 with WWW-Authenticate, log reason at debug.
|
||||
4. forwardAuthorized(rw, req, p).
|
||||
```
|
||||
|
||||
Target: ~40 lines.
|
||||
|
||||
### 7.6 Refactor of `processAuthorizedRequest` (modify `middleware.go`)
|
||||
|
||||
Splits along the principal boundary:
|
||||
- Session-specific part (backchannel-logout invalidation, `IsDirty` / `Save`) stays in `processAuthorizedRequest`.
|
||||
- Everything else moves to `forwardAuthorized`.
|
||||
- `processAuthorizedRequest` ends with `forwardAuthorized(rw, req, buildPrincipalFromSession(session))`.
|
||||
|
||||
### 7.7 `verifyOpts` extension to `verifyToken` (modify `token_manager.go`)
|
||||
|
||||
Add a parameter struct:
|
||||
```go
|
||||
type verifyOpts struct {
|
||||
skipReplayMarking bool // suppress JTI Set (token_manager.go:108-143); blacklist Get stays active
|
||||
}
|
||||
```
|
||||
|
||||
Both the type and field are unexported (internal-only knob). Signature change: `verifyToken(token string)` becomes `verifyToken(token string, opts verifyOpts)`. Existing callers pass `verifyOpts{}` (zero value = current behaviour). Bearer path passes `verifyOpts{skipReplayMarking: true}`.
|
||||
|
||||
**Critical semantics — must be reflected in implementation and tests:**
|
||||
- `skipReplayMarking` only gates the **Set** at `token_manager.go:108-143` (the call adding the JTI to the blacklist and replay cache).
|
||||
- The blacklist **Get** at `token_manager.go:44-47, 80-89` stays unconditionally active on the bearer path. Tokens revoked via `RevokeToken` (which adds the JTI to the blacklist) MUST still be rejected on the bearer path.
|
||||
- Must NOT be implemented by mutating `t.disableReplayDetection` (struct field) — that would create a cross-request race that disables replay protection globally.
|
||||
|
||||
A targeted regression test exercises: bearer token verified once → admin calls `RevokeToken` adding the JTI to the blacklist → same token replayed → 401.
|
||||
|
||||
### 7.8 Config additions (modify `settings.go`)
|
||||
|
||||
```go
|
||||
EnableBearerAuth bool `json:"enableBearerAuth,omitempty"`
|
||||
BearerIdentifierClaim string `json:"bearerIdentifierClaim,omitempty"`
|
||||
StripAuthorizationHeader bool `json:"stripAuthorizationHeader,omitempty"`
|
||||
BearerEmitWWWAuthenticate bool `json:"bearerEmitWWWAuthenticate,omitempty"`
|
||||
BearerOverridesCookie bool `json:"bearerOverridesCookie,omitempty"`
|
||||
MaxTokenAgeSeconds int64 `json:"maxTokenAgeSeconds,omitempty"`
|
||||
MaxIdentifierLength int `json:"maxIdentifierLength,omitempty"`
|
||||
BearerFailureThreshold int `json:"bearerFailureThreshold,omitempty"`
|
||||
BearerFailureWindowSeconds int `json:"bearerFailureWindowSeconds,omitempty"`
|
||||
BearerFailurePenaltySeconds int `json:"bearerFailurePenaltySeconds,omitempty"`
|
||||
```
|
||||
|
||||
Defaults (applied in `CreateConfig` for the bearer-related fields; values >0 only honoured when `EnableBearerAuth=true`):
|
||||
- `EnableBearerAuth`: `false`.
|
||||
- `BearerIdentifierClaim`: `"sub"`.
|
||||
- `StripAuthorizationHeader`: `true`.
|
||||
- `BearerEmitWWWAuthenticate`: `true` (RFC 6750 hint enabled by default; flip to false if recon-exposure is a concern).
|
||||
- `BearerOverridesCookie`: `false` (cookie wins when both present; flip to `true` for the legacy/industry-default behaviour).
|
||||
- `MaxTokenAgeSeconds`: `86400` (24h upper bound on `iat`).
|
||||
- `MaxIdentifierLength`: `256`.
|
||||
- `BearerFailureThreshold`: `20` (consecutive 401s per IP before throttle).
|
||||
- `BearerFailureWindowSeconds`: `60`.
|
||||
- `BearerFailurePenaltySeconds`: `60` (429 reply for this long after threshold tripped).
|
||||
|
||||
### 7.9 Startup validation (modify `main.go` `New()`)
|
||||
|
||||
- `EnableBearerAuth && Audience == ""` → fatal error.
|
||||
- `EnableBearerAuth && !StrictAudienceValidation` → warning log (recommended hardening).
|
||||
- `EnableBearerAuth && BearerIdentifierClaim == "email"` → fatal error (the bearer path is M2M and an `email` identifier without `email_verified` enforcement is a spoofing vector; default `BearerIdentifierClaim=sub` avoids this; explicit override to `email` is rejected).
|
||||
- `EnableBearerAuth && MaxTokenAgeSeconds <= 0` → reset to default 86400 with info log.
|
||||
- `EnableBearerAuth && BearerFailureThreshold <= 0` → reset to default 20 with info log.
|
||||
|
||||
## 8. Data Flow
|
||||
|
||||
### 8.1 Bearer path
|
||||
|
||||
```
|
||||
ServeHTTP entry (pre-init paths unchanged: logout, backchannel, frontchannel, excluded URLs, SSE/WS bypass)
|
||||
│
|
||||
├─ enableBearerAuth == false? → fall through to cookie path
|
||||
│
|
||||
└─ enableBearerAuth == true AND Authorization starts with "Bearer "
|
||||
│
|
||||
▼
|
||||
handleBearerRequest
|
||||
│
|
||||
├─ format guards (empty, length, segment count)
|
||||
│
|
||||
▼
|
||||
verifyToken(token, verifyOpts{SkipReplayMarking: true})
|
||||
│ signature, issuer, audience (strict), exp
|
||||
│
|
||||
▼
|
||||
classifyToken(claims) → reject ID tokens
|
||||
│
|
||||
▼
|
||||
if requireTokenIntrospection: introspectToken → active check
|
||||
│
|
||||
▼
|
||||
resolveIdentifier(claims) → sanitiseIdentifier
|
||||
│
|
||||
▼
|
||||
principal{Source: sourceBearer, …}
|
||||
│
|
||||
▼
|
||||
forwardAuthorized(rw, req, principal)
|
||||
│
|
||||
├─ roles/groups gate (403 on deny)
|
||||
├─ header injection
|
||||
├─ header templates
|
||||
├─ security headers
|
||||
├─ strip OIDC cookies (existing)
|
||||
├─ strip Authorization header (new, when configured)
|
||||
└─ next.ServeHTTP(rw, req)
|
||||
```
|
||||
|
||||
### 8.2 Cookie path (refactored, semantically unchanged)
|
||||
|
||||
```
|
||||
processAuthorizedRequest
|
||||
1. Session validity / backchannel-logout invalidation (unchanged).
|
||||
2. principal := buildPrincipalFromSession(session).
|
||||
3. forwardAuthorized(rw, req, principal).
|
||||
4. if session.IsDirty(): session.Save().
|
||||
```
|
||||
|
||||
## 9. Error Handling
|
||||
|
||||
| Trigger | Status | Body | WWW-Authenticate | Debug log reason |
|
||||
|---|---|---|---|---|
|
||||
| Empty bearer after prefix | 401 | `Unauthorized` | `Bearer error="invalid_request"` | empty bearer token |
|
||||
| Token over MaxLength | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token exceeds max length |
|
||||
| Not a 3-segment JWT | 401 | `Unauthorized` | `Bearer error="invalid_token"` | malformed JWT |
|
||||
| Disallowed `alg` (e.g. none, HS*) | 401 | `Unauthorized` | `Bearer error="invalid_token"` | unsupported alg |
|
||||
| Missing/oversized/bad-charset `kid` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | invalid kid |
|
||||
| Signature / issuer / aud / exp fail | 401 | `Unauthorized` | `Bearer error="invalid_token"` | reason from verifyToken (category only) |
|
||||
| `iat` older than MaxTokenAgeSeconds | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token too old (iat outside age bound) |
|
||||
| Multi-aud without matching `azp` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | multi-aud token without azp match |
|
||||
| Detected as ID token | 401 | `Unauthorized` | `Bearer error="invalid_token"` | ID tokens not accepted on bearer path |
|
||||
| JTI blacklisted (revoked) | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token JTI in blacklist |
|
||||
| Introspection `active=false` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token inactive at IdP |
|
||||
| Introspection endpoint failure | 503 | `Service Unavailable` | (none) | introspection unavailable |
|
||||
| Identifier claim missing/empty | 401 | `Unauthorized` | `Bearer error="invalid_token"` | no identifier claim |
|
||||
| Identifier fails sanitisation | 401 | `Unauthorized` | `Bearer error="invalid_token"` | invalid identifier characters |
|
||||
| Per-IP failure threshold tripped | 429 | `Too Many Requests` | (none); `Retry-After: <BearerFailurePenaltySeconds>` | source IP in penalty box |
|
||||
| Roles/groups not allowed | 403 | `Access denied` | (none) | user not in allowedRolesAndGroups |
|
||||
|
||||
Responses never include token contents, never include the raw failure reason, and never set `Location` headers (API clients cannot follow redirects).
|
||||
|
||||
## 10. Edge Cases
|
||||
|
||||
1. **Both bearer header and cookie session present.** Cookie wins by default (safer against browser/extension/proxy bearer injection). `BearerOverridesCookie=true` flips to bearer-wins. Either way: WARN log includes both source markers so operators can audit.
|
||||
2. **`Authorization: Basic …`.** Not bearer; cookie path runs as today.
|
||||
3. **`Authorization: Bearer ` (trailing space, no value).** Empty after trim → 401.
|
||||
4. **Mixed-case prefix (`bearer`, `BEARER`, `BeArEr`).** Case-insensitive prefix check; token value preserved verbatim.
|
||||
5. **Multiple `Authorization` headers.** Use only the first (Go `http.Header.Get` default). Documented.
|
||||
6. **Bearer during OIDC init wait.** Bearer requests also block on init: we need `issuerURL`, `audience`, JWKs ready. If init fails, bearer requests return 503 just like cookie requests.
|
||||
7. **SSE / WebSocket bypass with bearer.** Bypass paths keep cookie-only behaviour. Operators who want bearer on streaming endpoints must remove SSE/WS bypass. Documented.
|
||||
8. **Logout endpoint with bearer.** Logout runs before bearer detection. Treated as cookie-session logout; bearer token revocation requires IdP-side action.
|
||||
9. **Excluded URLs with bearer.** Bypass excluded URLs as today; bearer not validated on excluded paths. ADDITIONALLY: `Authorization: Bearer` is stripped from the request before forwarding so the token can't leak into the excluded endpoint's downstream logs / metrics scrapers / health checks.
|
||||
10. **Concurrent identical bearer requests.** Existing `tokenCache` is concurrency-safe; no new locking.
|
||||
11. **Client rotates token between requests.** Independent verification per token; independent cache entries.
|
||||
12. **Clock skew.** Use existing `jwt.Verify` leeway. (If absent, add ±30s as a separate change; out of scope here.)
|
||||
|
||||
## 11. Testing Strategy
|
||||
|
||||
### 11.1 Integration tests (new `bearer_auth_test.go`)
|
||||
|
||||
Table-driven test against a real `httptest.Server` and the full `ServeHTTP` flow. Coverage matrix:
|
||||
|
||||
- Valid access token + allowed roles → 200, `next` ran, `X-Forwarded-User` set.
|
||||
- Valid token without configured roles → 200.
|
||||
- Wrong audience, expired, tampered signature → 401, `next` did not run.
|
||||
- ID token presented → 401 (`ID tokens not accepted`).
|
||||
- Malformed JWT (2 segments) → 401.
|
||||
- Oversized token (> MaxLength) → 401.
|
||||
- Empty bearer → 401.
|
||||
- Missing identifier claim → 401.
|
||||
- Identifier containing `\r\n` → 401.
|
||||
- `allowedRolesAndGroups` mismatch → 403.
|
||||
- `allowedRolesAndGroups` match → 200.
|
||||
- `EnableBearerAuth=false` + bearer header → cookie path runs (302 to `/authorize`).
|
||||
- Bearer + valid cookie session → bearer wins, 200.
|
||||
- `StripAuthorizationHeader=true` → downstream sees no `Authorization`.
|
||||
- `StripAuthorizationHeader=false` → downstream sees `Authorization`.
|
||||
- Case variants (`bearer`, `BEARER`) → 200.
|
||||
- SSE bypass + bearer → cookie-only check applies (bearer ignored).
|
||||
- **Replay regression**: same token 1000 times in a row → all 200.
|
||||
- **Cache-evict regression**: same token, force-evict `tokenCache` between iterations (call `tokenCache.Delete` directly), replay → still 200 (verifies `skipReplayMarking` doesn't poison the blacklist).
|
||||
- **Revocation-while-bearer regression**: bearer token verified once → admin calls `RevokeToken` adding JTI to blacklist → same token presented → 401 (verifies blacklist Get stays active on bearer path even with `skipReplayMarking` set).
|
||||
- **Alg-pin: token signed with `alg=none`** → 401, no JWKS fetch happens (verify with a counting mock).
|
||||
- **`kid` injection: 50KB random kid** → 401 immediately, no JWKS fetch.
|
||||
- **Per-IP throttle**: 21 bad bearer requests from same IP within 1 minute → 22nd returns 429 + Retry-After.
|
||||
- **`iat` upper-age**: token with `iat = now - 25h` → 401 (older than 24h default).
|
||||
- **Multi-aud without azp**: aud = `["a", "b"]`, no azp → 401.
|
||||
- **Multi-aud with matching azp**: aud = `["api-aud", "other"]`, azp = clientID → 200.
|
||||
- **Identifier with bidi-override**: sub contains U+202E → 401.
|
||||
- **Identifier with comma**: sub = `"alice,bob"` → 401.
|
||||
- **Identifier over 256 bytes** → 401.
|
||||
- **`UserIdentifierClaim=email` at startup with EnableBearerAuth=true** → startup fails.
|
||||
- **Excluded URL + bearer**: bearer header presented on excluded URL → request forwarded, downstream sees no `Authorization` header (stripped).
|
||||
|
||||
### 11.2 Unit tests (in `bearer_auth_test.go`)
|
||||
|
||||
- `classifyToken`: ID-token detection, access-token detection by `scope`/`scp`/`token_use`, ambiguous → reject.
|
||||
- `resolveIdentifier`: precedence (`userIdentifierClaim` → `sub` → `client_id`/`azp`); missing → error; empty string → error.
|
||||
- `sanitizeIdentifier`: rejects all `unicode.IsControl`; accepts email/sub-style values.
|
||||
|
||||
### 11.3 Introspection tests (`bearer_auth_introspection_test.go`)
|
||||
|
||||
- Token valid + introspection `active=true` → 200.
|
||||
- Token valid + introspection `active=false` → 401.
|
||||
- Introspection endpoint 500 → 503.
|
||||
- Second request hits introspection cache (no second HTTP call).
|
||||
|
||||
### 11.4 Startup validation tests (extend `settings_test.go` / `main_test.go`)
|
||||
|
||||
- `EnableBearerAuth=true, Audience=""` → `New()` errors.
|
||||
- `EnableBearerAuth=true, StrictAudienceValidation=false` → succeeds with warning.
|
||||
- `EnableBearerAuth=false` → no validation; existing tests untouched.
|
||||
|
||||
### 11.5 Cookie-path regression suite
|
||||
|
||||
- All existing `TestServeHTTP_*` tests in `main_servehttp_test.go` pass unmodified.
|
||||
- Add: cookie session, `EnableBearerAuth=true`, no bearer header → identical behaviour to baseline.
|
||||
- Add: dirty session still triggers `Save()` after refactor.
|
||||
|
||||
### 11.6 Principal invariants
|
||||
|
||||
- `buildPrincipalFromSession`: `Source == sourceSession`; `IDToken` / `RefreshToken` populated when present in session.
|
||||
- `buildPrincipalFromBearerToken`: `Source == sourceBearer`; `IDToken == ""`, `RefreshToken == ""`.
|
||||
- `forwardAuthorized` produces identical headers for equivalent principals regardless of source.
|
||||
|
||||
### 11.7 Coverage gate
|
||||
|
||||
- New code in `bearer_auth.go` and `principal.go`: ≥ 90% line coverage.
|
||||
- `forwardAuthorized` coverage ≥ existing `processAuthorizedRequest` coverage baseline.
|
||||
|
||||
### 11.8 Out of scope (follow-ups)
|
||||
|
||||
- Load test of bearer vs cookie hot path.
|
||||
- Fuzzing the JWT parser.
|
||||
- Additional auth methods (mTLS, API keys) — design enables them, but they are separate work.
|
||||
|
||||
## 12. Migration / Rollout
|
||||
|
||||
Default-off. Existing deployments observe no behavioural change. Operators opt in by setting:
|
||||
|
||||
```yaml
|
||||
enableBearerAuth: true
|
||||
audience: https://api.example.com # required when bearer enabled
|
||||
# optional:
|
||||
stripAuthorizationHeader: true # default
|
||||
requireTokenIntrospection: false # default; set true for real-time revocation
|
||||
userIdentifierClaim: client_id # optional override; defaults to sub fallback chain
|
||||
```
|
||||
|
||||
Documentation: update `docs/CONFIGURATION.md` with a bearer-auth section, and add a new `docs/BEARER_AUTH.md` covering the security model, threat assumptions (token issuer is trusted; audience must be set; bearer means trust the issuer's revocation policy unless introspection enabled), and recommended configurations for common IdPs.
|
||||
|
||||
## 13. Security Considerations
|
||||
|
||||
| Concern | Mitigation |
|
||||
|---|---|
|
||||
| Token confusion (ID token used as bearer) | Reuse `detectTokenType` (`token_manager.go:187-303`) which checks `nonce`, `typ: at+jwt`, `token_use`, `scope`, aud-vs-clientID. Belt-and-braces: explicit `nonce` + `token_use == "id"` rejection on top. |
|
||||
| Audience confusion (token for service B accepted by A) | `Audience` mandatory at startup; verified via existing `VerifyJWTSignatureAndClaims`; multi-aud tokens require matching `azp == clientID`. |
|
||||
| Replay-via-blacklist false positive | `verifyOpts{skipReplayMarking: true}` on bearer path. Gates ONLY the Set; the Get stays so revoked tokens still fail. |
|
||||
| Revocation lag | Optional RFC 7662 introspection. Bearer-path introspection cache TTL capped at 60s. Set `RequireTokenIntrospection=true` for real-time revocation. |
|
||||
| `alg`-confusion / `alg=none` attacks | Hard-pin asymmetric allowlist at bearer entry, **before** JWKS fetch. Prevents wasted upstream calls and locks out HS/none probes. |
|
||||
| `kid` injection / JWKS amplification | `kid` length cap (256 bytes) + charset allowlist enforced at bearer entry. |
|
||||
| Bearer 401 brute-force / oracle | Per-IP `failedBearerAttempts` cache; configurable threshold + penalty box returning 429 + `Retry-After`. |
|
||||
| `iat` clock-manipulation / forever-tokens | `MaxTokenAgeSeconds` upper bound (default 24h); cookie path unchanged. |
|
||||
| Identifier-driven header injection | `sanitizeIdentifier`: length cap, control-char + bidi-override + `,;=` rejection. `net/http` rejects CRLF on the wire too (defence in depth). |
|
||||
| Token leakage downstream | `StripAuthorizationHeader=true` by default. Also: `Authorization` stripped on excluded-URL requests so bearer can't leak into health/metrics downstream logs. |
|
||||
| Token-in-logs | All log paths log reason categories, not raw tokens. Identifier hashed via SHA-256 truncated to 8 hex chars before any info/warn-level emission (full identifier only at debug). New `safeLogAuthEvent(category, hashedIdentifier, reasonCode)` helper makes this hard to misuse. |
|
||||
| `email` claim spoofing | Startup fails if `EnableBearerAuth && UserIdentifierClaim == "email"`. Future human-user bearer iteration must add `email_verified` enforcement. |
|
||||
| Bypass on SSE / WS endpoints | SSE/WS bypass keeps cookie-only behaviour; bearer ignored. Operators choose to widen if needed. |
|
||||
| Mixed bearer + cookie precedence | Cookie wins by default (safer for browser scenarios); `BearerOverridesCookie=true` flips. WARN log on both-present requests. |
|
||||
| Configuration drift (operator forgets audience) | Startup fails when `EnableBearerAuth=true && Audience==""`. |
|
||||
| Downstream blast radius when `StripAuthorizationHeader=false` | Documented: forwarded bearer extends token's blast radius to all downstream services. Logs at those services become token stores. Operators must treat downstream log policy accordingly. |
|
||||
| Introspection auth method (pre-existing gap, called out) | `token_introspection.go:80` uses `client_secret_basic` only; does not honour `private_key_jwt`. Out of scope for this PR but documented as a follow-up; operators using `ClientAuthMethod=private_key_jwt` + `RequireTokenIntrospection=true` should be aware introspection will use basic auth. |
|
||||
|
||||
## 14. Open Questions
|
||||
|
||||
None — all design decisions resolved during brainstorming + security review. Implementation may surface incidental questions (e.g. exact clock-skew leeway in `jwt.Verify`); those are out of scope for this spec and handled in the implementation plan.
|
||||
|
||||
## 14a. Security Review Reference
|
||||
|
||||
This design was reviewed by the `security-reviewer` subagent on 2026-05-18. Findings incorporated:
|
||||
|
||||
- **Critical**: C1 (classifier reuses `detectTokenType`), C2 (sub fallback dropped — unreachable due to `jwt.go:416`), C3 (replay-marking gates only Set, not Get; revocation regression test added).
|
||||
- **High**: H1 (alg pinned at bearer entry), H2 (kid length + charset), H3 (cookie wins by default, configurable), H4 (per-IP 401 throttle), H5 (multi-aud requires azp).
|
||||
- **Medium**: M1 (identifier max-length + bidi reject + delimiter chars), M2 (introspection cache TTL capped at 60s on bearer path), M4 (log-hashing via SHA-256[:8]), M5 (StripAuth blast-radius documented), M6 (iat upper-age bound), M7 (Authorization stripped on excluded URLs).
|
||||
- **Low/Nit**: L2 (renamed to `BearerEmitWWWAuthenticate`), N3 (startup rejects `UserIdentifierClaim=email`).
|
||||
- **Documented as pre-existing gaps (follow-up PRs)**: M3 (introspection auth method doesn't honour `private_key_jwt`).
|
||||
|
||||
## 15. Implementation Plan Reference
|
||||
|
||||
To be produced by the `writing-plans` skill in a follow-up document at `docs/superpowers/plans/2026-05-18-bearer-token-auth-plan.md`. The plan decomposes this design into ordered, independently-testable PRs.
|
||||
@@ -370,21 +370,6 @@ func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, res
|
||||
return r.saveCredentials(resp)
|
||||
}
|
||||
|
||||
// deleteCredentialsFromStore removes credentials from the configured storage backend
|
||||
// Falls back to legacy file-based deletion if no store is configured
|
||||
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Delete(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based deletion
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
@@ -423,187 +408,3 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse,
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateClientRegistration updates an existing client registration using RFC 7592
|
||||
// This requires the registration_client_uri and registration_access_token from the original registration
|
||||
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to update")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Build update request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build update request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create update request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read update response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update response: %w", err)
|
||||
}
|
||||
|
||||
// Update cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// ReadClientRegistration reads the current client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to read")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create read request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse read response: %w", err)
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// DeleteClientRegistration deletes the client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return fmt.Errorf("no existing registration to delete")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle error responses (204 No Content is success)
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials from storage if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.deleteCredentialsFromStore(ctx); err != nil {
|
||||
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Successfully deleted client registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -735,258 +735,6 @@ func TestDCRConfigDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateClientRegistration tests the RFC 7592 client update functionality
|
||||
func TestUpdateClientRegistration(t *testing.T) {
|
||||
updateCalled := false
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPut {
|
||||
updateCalled = true
|
||||
|
||||
// Verify authorization header
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
t.Error("Missing Authorization header for update")
|
||||
}
|
||||
|
||||
resp := ClientRegistrationResponse{
|
||||
ClientID: "updated-client-id",
|
||||
ClientSecret: "updated-client-secret",
|
||||
RegistrationAccessToken: "new-access-token",
|
||||
RegistrationClientURI: r.URL.String(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
ClientMetadata: &ClientRegistrationMetadata{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
},
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "original-client-id",
|
||||
ClientSecret: "original-client-secret",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Perform update
|
||||
ctx := context.Background()
|
||||
resp, err := registrar.UpdateClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
if !updateCalled {
|
||||
t.Error("Update endpoint was not called")
|
||||
}
|
||||
|
||||
if resp.ClientID != "updated-client-id" {
|
||||
t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality
|
||||
func TestDeleteClientRegistration(t *testing.T) {
|
||||
deleteCalled := false
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodDelete {
|
||||
deleteCalled = true
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
credentialsFile := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
// Create a credentials file to test deletion
|
||||
os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600)
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
CredentialsFile: credentialsFile,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Perform delete
|
||||
ctx := context.Background()
|
||||
err := registrar.DeleteClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if !deleteCalled {
|
||||
t.Error("Delete endpoint was not called")
|
||||
}
|
||||
|
||||
// Verify cache is cleared
|
||||
if registrar.GetCachedResponse() != nil {
|
||||
t.Error("Cached response should be cleared after deletion")
|
||||
}
|
||||
|
||||
// Verify credentials file is deleted
|
||||
if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) {
|
||||
t.Error("Credentials file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadClientRegistration tests the RFC 7592 client read functionality
|
||||
func TestReadClientRegistration(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet {
|
||||
resp := ClientRegistrationResponse{
|
||||
ClientID: "read-client-id",
|
||||
ClientSecret: "read-client-secret",
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ResponseTypes: []string{"code"},
|
||||
GrantTypes: []string{"authorization_code"},
|
||||
ApplicationType: "web",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "original-client-id",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Read registration
|
||||
ctx := context.Background()
|
||||
resp, err := registrar.ReadClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.ClientID != "read-client-id" {
|
||||
t.Errorf("Read ClientID mismatch: got %s", resp.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationsWithoutCachedResponse tests error handling when no cached response exists
|
||||
func TestOperationsWithoutCachedResponse(t *testing.T) {
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
&http.Client{},
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Update without cached response
|
||||
_, err := registrar.UpdateClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Update should fail without cached response: %v", err)
|
||||
}
|
||||
|
||||
// Test Read without cached response
|
||||
_, err = registrar.ReadClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Read should fail without cached response: %v", err)
|
||||
}
|
||||
|
||||
// Test Delete without cached response
|
||||
err = registrar.DeleteClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Delete should fail without cached response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationsWithoutManagementCredentials tests error handling without management URIs
|
||||
func TestOperationsWithoutManagementCredentials(t *testing.T) {
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
&http.Client{},
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
// Set up cached response WITHOUT management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
// Missing RegistrationAccessToken and RegistrationClientURI
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Update without management credentials
|
||||
_, err := registrar.UpdateClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Update should fail without management credentials: %v", err)
|
||||
}
|
||||
|
||||
// Test Read without management credentials
|
||||
_, err = registrar.ReadClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Read should fail without management credentials: %v", err)
|
||||
}
|
||||
|
||||
// Test Delete without management credentials
|
||||
err = registrar.DeleteClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Delete should fail without management credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stringContains is a helper function to check if a string contains a substring
|
||||
func stringContains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr))
|
||||
|
||||
@@ -2,6 +2,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -40,6 +42,31 @@ func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, http
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if jwks == nil {
|
||||
return nil, fmt.Errorf("JWKS is nil")
|
||||
}
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Cleanup() {
|
||||
atomic.AddInt32(&m.CleanupCalls, 1)
|
||||
m.mu.Lock()
|
||||
|
||||
+7
-6
@@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
errStr := strings.ToLower(err.Error())
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
if contains(errStr, retryableErr) {
|
||||
if contains(errStr, strings.ToLower(retryableErr)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
errStr := netErr.Error()
|
||||
errStr := strings.ToLower(netErr.Error())
|
||||
temporaryPatterns := []string{
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
@@ -859,8 +859,9 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
|
||||
|
||||
// isServiceDegraded checks if a service is currently degraded
|
||||
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
// Uses a write lock because the recovery-timeout branch deletes from the map.
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
degradedTime, exists := gd.degradedServices[serviceName]
|
||||
if !exists {
|
||||
@@ -954,7 +955,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
|
||||
var degraded []string
|
||||
degraded := make([]string, 0, len(gd.degradedServices))
|
||||
for serviceName := range gd.degradedServices {
|
||||
degraded = append(degraded, serviceName)
|
||||
}
|
||||
|
||||
@@ -101,6 +101,16 @@ http:
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Optional: switch to RFC 7523 private_key_jwt client auth
|
||||
# (Entra ID, Okta, Auth0, Keycloak). Replaces clientSecret with a
|
||||
# signed JWT assertion. See README for details and PEM formats.
|
||||
# ----------------------------------------------------------------
|
||||
# clientAuthMethod: "private_key_jwt"
|
||||
# clientAssertionKeyPath: "/etc/traefik/oidc/client-key.pem"
|
||||
# clientAssertionKeyID: "prod-key-2026"
|
||||
# clientAssertionAlg: "RS256" # or PS256/384/512, ES256/384/512
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
@@ -4,8 +4,8 @@ go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
|
||||
@@ -12,12 +12,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3 h1:xoDtBqeZGmXj7IteiE1M5WMuzeoqag58qEleI0Cf2Ms=
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3/go.mod h1:+Cn78qZo8rc3T9eZt0v3oICYRdd75wORtSidc8lNjDQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
|
||||
+62
-5
@@ -17,6 +17,21 @@ import (
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// newUUIDv4 returns an RFC 4122 v4 UUID string (e.g.
|
||||
// "f47ac10b-58cc-4372-a567-0e02b2c3d479") backed by crypto/rand. Used for CSRF
|
||||
// tokens and other opaque random identifiers — replaces github.com/google/uuid
|
||||
// to keep the plugin stdlib-only on the production path.
|
||||
func newUUIDv4() (string, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", fmt.Errorf("could not generate UUID: %w", err)
|
||||
}
|
||||
b[6] = (b[6] & 0x0f) | 0x40 // version 4
|
||||
b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
|
||||
}
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
|
||||
// The nonce is used to prevent replay attacks and associate client sessions with ID tokens.
|
||||
// Returns:
|
||||
@@ -92,9 +107,12 @@ type TokenResponse struct {
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant)
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
"grant_type": {grantType},
|
||||
}
|
||||
// client_id is sent in the body for every method except client_secret_basic,
|
||||
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
|
||||
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
|
||||
data.Set("client_id", t.clientID)
|
||||
}
|
||||
|
||||
if grantType == "authorization_code" {
|
||||
@@ -126,16 +144,33 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
}
|
||||
}
|
||||
|
||||
// Read tokenURL with RLock
|
||||
// Read tokenURL with RLock — needed as audience for private_key_jwt (RFC 7523 §3).
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
useBasicAuth := false
|
||||
if t.clientAssertion != nil {
|
||||
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign client assertion: %w", err)
|
||||
}
|
||||
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
data.Set("client_assertion", assertion)
|
||||
} else if t.clientAuthMethod == "client_secret_basic" {
|
||||
useBasicAuth = true
|
||||
} else {
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if useBasicAuth {
|
||||
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
@@ -357,10 +392,19 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
// localRedirect is used when there is no provider end-session endpoint and
|
||||
// the plugin redirects the browser itself. It must never be an absolute URL
|
||||
// derived from the request host (X-Forwarded-Host is client-controllable and
|
||||
// would be an open redirect); use a host-relative path, or the operator's
|
||||
// own configured absolute URL, instead.
|
||||
localRedirect := "/"
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
localRedirect = normalizeLogoutPath(postLogoutRedirectURI)
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
} else {
|
||||
localRedirect = postLogoutRedirectURI
|
||||
}
|
||||
|
||||
// Read endSessionURL with RLock
|
||||
@@ -379,7 +423,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
http.Redirect(rw, req, localRedirect, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint.
|
||||
@@ -408,6 +452,19 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// setOAuthBasicAuth sets the Authorization header per RFC 6749 §2.3.1: the
|
||||
// client_id and client_secret are form-urlencoded individually, joined with a
|
||||
// colon, then base64-encoded. This differs from http.Request.SetBasicAuth,
|
||||
// which skips the form-urlencode step — that matters for credentials with
|
||||
// reserved characters (`:`, `@`, `+`, `%`, etc.) where the wire format would
|
||||
// otherwise diverge from what the spec mandates.
|
||||
func setOAuthBasicAuth(req *http.Request, clientID, clientSecret string) {
|
||||
user := url.QueryEscape(clientID)
|
||||
pass := url.QueryEscape(clientSecret)
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
// deduplicateScopes removes duplicate scopes from a slice while preserving order.
|
||||
// This ensures that OAuth scope parameters don't contain duplicates which could
|
||||
// cause issues with some authorization servers.
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewUUIDv4 verifies the in-house UUID v4 generator produces RFC 4122
|
||||
// compliant identifiers. Locks in the replacement for github.com/google/uuid
|
||||
// — a regression here would weaken the CSRF token used in the OIDC flow.
|
||||
func TestNewUUIDv4(t *testing.T) {
|
||||
rfc4122v4 := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
|
||||
|
||||
const samples = 1000
|
||||
seen := make(map[string]struct{}, samples)
|
||||
for i := 0; i < samples; i++ {
|
||||
got, err := newUUIDv4()
|
||||
if err != nil {
|
||||
t.Fatalf("newUUIDv4 failed: %v", err)
|
||||
}
|
||||
if !rfc4122v4.MatchString(got) {
|
||||
t.Fatalf("UUID %q does not match RFC 4122 v4 format", got)
|
||||
}
|
||||
if _, dup := seen[got]; dup {
|
||||
t.Fatalf("duplicate UUID emitted within %d samples: %q", samples, got)
|
||||
}
|
||||
seen[got] = struct{}{}
|
||||
}
|
||||
}
|
||||
+13
-5
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -25,10 +26,16 @@ type HTTPClientConfig struct {
|
||||
Timeout time.Duration
|
||||
MaxConnsPerHost int
|
||||
WriteBufferSize int
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
// RootCAs is an optional certificate pool used for TLS verification.
|
||||
// A nil pool means "use the system trust store" (default behavior).
|
||||
RootCAs *x509.CertPool
|
||||
// InsecureSkipVerify disables TLS certificate verification.
|
||||
// ONLY set this for local development against self-signed certificates.
|
||||
InsecureSkipVerify bool
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig returns the default configuration for general use
|
||||
@@ -203,7 +210,8 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: false, // Always verify certificates
|
||||
RootCAs: config.RootCAs,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
|
||||
},
|
||||
ForceAttemptHTTP2: config.ForceHTTP2,
|
||||
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
|
||||
|
||||
+48
-9
@@ -3,6 +3,7 @@ package traefikoidc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
@@ -25,6 +26,10 @@ type sharedTransport struct {
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
refCount int
|
||||
// tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify)
|
||||
// this transport was built with, so the at-limit fallback only reuses a
|
||||
// transport whose TLS configuration matches the caller's.
|
||||
tlsKey string
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -52,19 +57,26 @@ func GetGlobalTransportPool() *SharedTransportPool {
|
||||
|
||||
// GetOrCreateTransport gets or creates a shared transport with the given config
|
||||
func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport {
|
||||
// SECURITY FIX: Check client limit before creating new transport
|
||||
// SECURITY FIX: Check client limit before creating new transport.
|
||||
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
|
||||
// Return existing transport if limit reached
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
// At the client limit: only reuse a transport that was built for the
|
||||
// SAME config (same TLS trust store). refCount is mutated under the
|
||||
// write lock to avoid a data race, and a transport created for a
|
||||
// different configuration is never handed back — doing so could apply
|
||||
// the wrong (possibly verification-disabled) TLS settings to a request.
|
||||
want := tlsConfigKey(config)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, shared := range p.transports {
|
||||
if shared != nil && shared.transport != nil {
|
||||
if shared != nil && shared.transport != nil && shared.tlsKey == want {
|
||||
shared.refCount++
|
||||
shared.lastUsed = time.Now()
|
||||
return shared.transport
|
||||
}
|
||||
}
|
||||
// If no transport available, return nil (caller should handle)
|
||||
// No TLS-compatible transport available; return nil so the caller falls
|
||||
// back to a default, certificate-verifying transport rather than one
|
||||
// with a different (possibly verification-disabled) trust store.
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -103,7 +115,8 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: false,
|
||||
RootCAs: config.RootCAs,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
|
||||
},
|
||||
ForceAttemptHTTP2: config.ForceHTTP2,
|
||||
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
|
||||
@@ -123,6 +136,7 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
|
||||
transport: transport,
|
||||
refCount: 1,
|
||||
lastUsed: time.Now(),
|
||||
tlsKey: tlsConfigKey(config),
|
||||
}
|
||||
|
||||
return transport
|
||||
@@ -205,8 +219,33 @@ func (p *SharedTransportPool) performCleanup() {
|
||||
|
||||
// configKey generates a unique key for a config
|
||||
func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
|
||||
// Simple key based on main parameters
|
||||
return string(rune(config.MaxConnsPerHost)) + string(rune(config.MaxIdleConnsPerHost))
|
||||
// Pool transports by the parameters that change TLS or connection
|
||||
// behavior. RootCAs and InsecureSkipVerify MUST be part of the key:
|
||||
// otherwise a middleware configured with a custom CA would share a
|
||||
// transport with one using the system store, silently bypassing its
|
||||
// CA configuration.
|
||||
skip := "0"
|
||||
if config.InsecureSkipVerify {
|
||||
skip = "1"
|
||||
}
|
||||
return fmt.Sprintf("%d|%d|%p|%s",
|
||||
config.MaxConnsPerHost,
|
||||
config.MaxIdleConnsPerHost,
|
||||
config.RootCAs,
|
||||
skip,
|
||||
)
|
||||
}
|
||||
|
||||
// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool
|
||||
// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are
|
||||
// safe to serve from the same transport even if other (non-TLS) parameters such
|
||||
// as connection limits differ; configs with different TLS settings are not.
|
||||
func tlsConfigKey(config HTTPClientConfig) string {
|
||||
skip := "0"
|
||||
if config.InsecureSkipVerify {
|
||||
skip = "1"
|
||||
}
|
||||
return fmt.Sprintf("%p|%s", config.RootCAs, skip)
|
||||
}
|
||||
|
||||
// Cleanup closes all transports and stops the cleanup goroutine
|
||||
|
||||
+14
-27
@@ -10,6 +10,14 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Pre-compiled regex patterns for validation (const patterns should use MustCompile)
|
||||
var (
|
||||
emailRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
urlRegexPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
|
||||
tokenRegexPattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
|
||||
usernameRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
|
||||
)
|
||||
|
||||
// InputValidator provides comprehensive input validation and sanitization
|
||||
// to protect against common security vulnerabilities including SQL injection,
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
@@ -73,7 +81,7 @@ func DefaultInputValidationConfig() InputValidationConfig {
|
||||
}
|
||||
|
||||
// NewInputValidator creates a new input validator with the specified configuration.
|
||||
// It compiles all necessary regex patterns and initializes security pattern lists.
|
||||
// It uses pre-compiled regex patterns and initializes security pattern lists.
|
||||
//
|
||||
// Parameters:
|
||||
// - config: Validation configuration with size limits and mode settings.
|
||||
@@ -81,29 +89,8 @@ func DefaultInputValidationConfig() InputValidationConfig {
|
||||
//
|
||||
// Returns:
|
||||
// - A configured InputValidator instance.
|
||||
// - An error if regex compilation fails.
|
||||
// - An error (always nil, kept for API compatibility).
|
||||
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
|
||||
// Compile regex patterns
|
||||
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile email regex: %w", err)
|
||||
}
|
||||
|
||||
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
|
||||
}
|
||||
|
||||
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile token regex: %w", err)
|
||||
}
|
||||
|
||||
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile username regex: %w", err)
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
@@ -112,10 +99,10 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
emailRegex: emailRegexPattern,
|
||||
urlRegex: urlRegexPattern,
|
||||
tokenRegex: tokenRegexPattern,
|
||||
usernameRegex: usernameRegexPattern,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
|
||||
Vendored
+3
@@ -24,6 +24,7 @@ type Config struct {
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
TLSServerName string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
@@ -34,6 +35,8 @@ type Config struct {
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
EnableTLS bool
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+82
-35
@@ -20,6 +20,7 @@ type HybridBackend struct {
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
l1BackfillBuffer chan *l1BackfillItem
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
@@ -28,6 +29,7 @@ type HybridBackend struct {
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
l1BackfillDrops atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// l1BackfillItem represents a deferred write of an L2-resolved value back into
|
||||
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
|
||||
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
|
||||
type l1BackfillItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
|
||||
// L2 hits regardless of request rate.
|
||||
h.wg.Add(1)
|
||||
go h.l1BackfillWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
@@ -147,7 +164,7 @@ func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl t
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", redactKey(key))
|
||||
return nil // Don't fail the operation if L2 is down
|
||||
}
|
||||
|
||||
@@ -159,13 +176,13 @@ func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl t
|
||||
// Synchronous write for critical cache types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", redactKey(key), err)
|
||||
h.recordL2Error()
|
||||
// Don't fail the operation - L1 write succeeded
|
||||
return nil
|
||||
}
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", redactKey(key))
|
||||
} else {
|
||||
// Asynchronous write for non-critical cache types
|
||||
select {
|
||||
@@ -175,10 +192,10 @@ func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl t
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", key)
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", redactKey(key))
|
||||
default:
|
||||
// Buffer is full, log and continue
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", redactKey(key))
|
||||
h.errors.Add(1)
|
||||
}
|
||||
}
|
||||
@@ -192,7 +209,7 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
value, ttl, exists, err := h.primary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", key, err)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", redactKey(key), err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
@@ -210,7 +227,7 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
value, ttl, exists, err = h.secondary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", key, err)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", redactKey(key), err)
|
||||
h.recordL2Error()
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil // Don't propagate L2 errors
|
||||
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
// Populate L1 cache with value from L2 (write-through on read).
|
||||
// Hand off to the bounded backfill worker instead of spawning a goroutine
|
||||
// per read - under burst that would mint thousands of goroutines.
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
close(h.l1BackfillBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
|
||||
// Drops on full buffer to keep the read path constant-time; the next L2 hit
|
||||
// for the same key simply re-queues it.
|
||||
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
|
||||
select {
|
||||
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
|
||||
default:
|
||||
h.l1BackfillDrops.Add(1)
|
||||
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", redactKey(key))
|
||||
}
|
||||
}
|
||||
|
||||
// l1BackfillWorker drains the backfill queue serially. Single worker is
|
||||
// intentional - L1 writes are local and cheap, and serializing them keeps
|
||||
// goroutine count bounded under any read rate.
|
||||
func (h *HybridBackend) l1BackfillWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items best-effort then exit.
|
||||
for len(h.l1BackfillBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.l1BackfillBuffer:
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.l1BackfillBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", redactKey(item.key), err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", redactKey(item.key))
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
@@ -572,11 +619,11 @@ func (h *HybridBackend) asyncWriteWorker() {
|
||||
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
|
||||
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", redactKey(item.key), err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", redactKey(item.key))
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
|
||||
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
|
||||
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
|
||||
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 256,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
const burst = 1000
|
||||
|
||||
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
|
||||
// fresh L1 backfill enqueue.
|
||||
for i := 0; i < burst; i++ {
|
||||
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
|
||||
}
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Issue the burst as fast as possible; the backfill worker MUST be the
|
||||
// only goroutine doing L1 writes. Allow brief slack for the test runtime
|
||||
// scheduling but anything north of +20 means goroutine leakage.
|
||||
peak := baseline
|
||||
for i := 0; i < burst; i++ {
|
||||
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
if g := runtime.NumGoroutine(); g > peak {
|
||||
peak = g
|
||||
}
|
||||
}
|
||||
|
||||
delta := peak - baseline
|
||||
if delta > 20 {
|
||||
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
|
||||
delta, baseline, peak)
|
||||
}
|
||||
|
||||
// L1 must eventually catch up via the worker. Worker drains serially so
|
||||
// give it a generous window proportional to the burst size.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
var populated int
|
||||
for i := 0; i < burst; i++ {
|
||||
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
|
||||
populated++
|
||||
}
|
||||
}
|
||||
// Be lenient: drops are acceptable under buffer pressure, just want
|
||||
// most of the keys to make it.
|
||||
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
|
||||
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
|
||||
}
|
||||
|
||||
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
|
||||
// buffer is saturated. Drops must be counted, never block, never spawn a
|
||||
// goroutine.
|
||||
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
// Tiny buffer + slow primary writes via failSet so the worker stays
|
||||
// blocked enough to overflow the buffer.
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
// Stop the worker from draining: cancel the underlying context so the
|
||||
// worker bails out, leaving us with a cold buffer and the queue method
|
||||
// itself responsible for drop accounting.
|
||||
hybrid.cancel()
|
||||
// Wait for worker to exit so it can't drain.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
|
||||
}
|
||||
|
||||
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
|
||||
"expected some drops when buffer is saturated and worker is stopped")
|
||||
}
|
||||
+26
@@ -0,0 +1,26 @@
|
||||
// Package backends provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// redactKey returns a short, deterministic hash prefix of a cache key for use
|
||||
// in debug/info log lines. Cache keys in this plugin can include raw access /
|
||||
// refresh / id tokens (any caller may pass an arbitrary string), and CodeQL
|
||||
// flags `key=%s` formatters as a clear-text-logging sink for HTTP-header-
|
||||
// sourced taint. The hash preserves cache-key uniqueness in logs (same key →
|
||||
// same hash, useful for correlating a problematic key across log lines) while
|
||||
// keeping the raw value out of disk-resident log streams.
|
||||
//
|
||||
// 8 hex chars (32 bits) is enough to disambiguate at human-debugging scale
|
||||
// without making the hash itself a useful lookup primitive for an attacker
|
||||
// who only has the log stream.
|
||||
func redactKey(key string) string {
|
||||
if key == "" {
|
||||
return "(empty)"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:4])
|
||||
}
|
||||
+9
-5
@@ -241,9 +241,11 @@ func (s *cacheShard) evictLRULocked() bool {
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
item, ok := element.Value.(*memoryCacheItem)
|
||||
if ok {
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -267,8 +269,10 @@ func (s *cacheShard) getOldestAccessTime() time.Time {
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
return item.accessedAt
|
||||
item, ok := element.Value.(*memoryCacheItem)
|
||||
if ok {
|
||||
return item.accessedAt
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
Vendored
+5
-2
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
TLSServerName: config.TLSServerName,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
EnableTLS: config.EnableTLS,
|
||||
TLSSkipVerify: config.TLSSkipVerify,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
@@ -345,7 +348,7 @@ func (r *RedisBackend) prefixKey(key string) string {
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
|
||||
// It checks context cancellation at multiple points to ensure fast abort when the
|
||||
// caller's context is cancelled (e.g., due to request timeout).
|
||||
// caller's context is canceled (e.g., due to request timeout).
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
|
||||
@@ -377,7 +380,7 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// Check context after operation - if cancelled, don't bother retrying
|
||||
// Check context after operation - if canceled, don't bother retrying
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
+25
-3
@@ -2,6 +2,7 @@ package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -31,6 +32,7 @@ type ConnectionPool struct {
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
@@ -39,6 +41,8 @@ type PoolConfig struct {
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
||||
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
conn, err = p.createConnection(ctx)
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if p.config.EnableTLS {
|
||||
serverName := p.config.TLSServerName
|
||||
if serverName == "" {
|
||||
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
||||
serverName = host
|
||||
}
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
} else {
|
||||
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
+31
-1
@@ -3,6 +3,7 @@ package backends
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -201,7 +202,7 @@ func TestConnectionPool_ContextCancellation(t *testing.T) {
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get another with cancelled context
|
||||
// Try to get another with canceled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
@@ -617,4 +618,33 @@ func TestRedisConn_TooManyArguments(t *testing.T) {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// TestRedisConn_RejectOversizedArgumentBytes is a regression test for CodeQL
|
||||
// alert #10 (go/allocation-size-overflow). A single argument larger than
|
||||
// maxTotalArgBytes (64 MiB) must be rejected by the per-argument overflow
|
||||
// guard in Do() before any allocation is attempted.
|
||||
func TestRedisConn_RejectOversizedArgumentBytes(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
largeArg := strings.Repeat("x", (64<<20)+1)
|
||||
|
||||
_, err = conn.Do("SET", "k", largeArg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "arguments too large")
|
||||
}
|
||||
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// drainRESPRequest consumes a single RESP request (array or inline) from r and
|
||||
// returns true on success. Any read error returns false.
|
||||
func drainRESPRequest(r *bufio.Reader) bool {
|
||||
header, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(header, "*") {
|
||||
return true // inline command (single line) — already consumed
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
|
||||
if err != nil || n <= 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// Each bulk: "$len\r\n<bytes>\r\n"
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
|
||||
// answer PING with +PONG. Returns the listener address and a self-signed cert.
|
||||
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
c, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer conn.Close()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
if !drainRESPRequest(reader) {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("+PONG\r\n"))
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
close(stopCh)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return listener.Addr().String(), der, stop
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
|
||||
// Regression test for issue #133 (enableTLS not propagated to client).
|
||||
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=false rejects a self-signed server cert.
|
||||
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "tls")
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
|
||||
// fails to handshake against a plain (non-TLS) listener.
|
||||
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
|
||||
plain, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer plain.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, acceptErr := plain.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: plain.Addr().String(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
|
||||
// when EnableTLS=false (default).
|
||||
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
Vendored
+15
-34
@@ -7,52 +7,34 @@ import (
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RESP (REdis Serialization Protocol) implementation
|
||||
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
|
||||
//
|
||||
// NOTE: sync.Pool was intentionally removed for Yaegi compatibility.
|
||||
// Yaegi (Traefik's Go interpreter) has issues with sync.Pool and reflection
|
||||
// that cause "reflect: call of reflect.Value.Field on zero Value" panics.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/120
|
||||
|
||||
var (
|
||||
ErrInvalidRESP = errors.New("invalid RESP response")
|
||||
ErrNilResponse = errors.New("nil response")
|
||||
)
|
||||
|
||||
// Object pools for memory optimization - reduces allocations by 50-70%
|
||||
var (
|
||||
readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(nil, 4096),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
writerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPWriter{
|
||||
w: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// RESPWriter writes RESP protocol messages
|
||||
type RESPWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
|
||||
// NewRESPWriter creates a new RESP writer
|
||||
func NewRESPWriter(w io.Writer) *RESPWriter {
|
||||
writer := writerPool.Get().(*RESPWriter)
|
||||
writer.w = w
|
||||
return writer
|
||||
return &RESPWriter{w: w}
|
||||
}
|
||||
|
||||
// Release returns the writer to the pool for reuse
|
||||
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
|
||||
func (w *RESPWriter) Release() {
|
||||
w.w = nil
|
||||
writerPool.Put(w)
|
||||
// No-op: pooling removed for Yaegi compatibility
|
||||
}
|
||||
|
||||
// WriteCommand writes a Redis command in RESP array format
|
||||
@@ -78,17 +60,16 @@ type RESPReader struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
|
||||
// NewRESPReader creates a new RESP reader
|
||||
func NewRESPReader(r io.Reader) *RESPReader {
|
||||
reader := readerPool.Get().(*RESPReader)
|
||||
reader.r.Reset(r)
|
||||
return reader
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(r, 4096),
|
||||
}
|
||||
}
|
||||
|
||||
// Release returns the reader to the pool for reuse
|
||||
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
|
||||
func (r *RESPReader) Release() {
|
||||
r.r.Reset(nil)
|
||||
readerPool.Put(r)
|
||||
// No-op: pooling removed for Yaegi compatibility
|
||||
}
|
||||
|
||||
// ReadResponse reads a RESP response and returns the parsed value
|
||||
|
||||
+1
-1
@@ -87,7 +87,7 @@ func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher
|
||||
// If successful, store in cache
|
||||
if call.err == nil && call.val != nil {
|
||||
// Use a background context for cache storage to ensure it completes
|
||||
// even if the original context is cancelled
|
||||
// even if the original context is canceled
|
||||
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
|
||||
cancel()
|
||||
|
||||
Vendored
+2
-2
@@ -190,7 +190,7 @@ func (c *Cache) Set(key string, value interface{}, ttl time.Duration) error {
|
||||
c.currentSize++
|
||||
atomic.AddInt64(&c.sets, 1)
|
||||
|
||||
c.logger.Debugf("Cache: Set key=%s, size=%d, ttl=%v", key, size, ttl)
|
||||
c.logger.Debugf("Cache: Set key=%s, size=%d, ttl=%v", redactKey(key), size, ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -346,7 +346,7 @@ func (c *Cache) evictLRU() {
|
||||
item, _ := elem.Value.(*Item) // Safe to ignore: type assertion from known type
|
||||
c.removeItem(item.Key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
c.logger.Debugf("Cache: Evicted LRU item key=%s", item.Key)
|
||||
c.logger.Debugf("Cache: Evicted LRU item key=%s", redactKey(item.Key))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+22
@@ -0,0 +1,22 @@
|
||||
// Package cache provides the in-memory cache implementation for the Traefik
|
||||
// OIDC plugin.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// redactKey returns a short, deterministic hash prefix of a cache key for use
|
||||
// in debug/info log lines. Cache keys may include raw access / refresh / id
|
||||
// tokens (callers pass arbitrary strings) and CodeQL flags `key=%s`
|
||||
// formatters as a clear-text-logging sink for HTTP-header-sourced taint.
|
||||
// The hash preserves uniqueness in logs (same key → same hash) while keeping
|
||||
// the raw value out of disk-resident log streams.
|
||||
func redactKey(key string) string {
|
||||
if key == "" {
|
||||
return "(empty)"
|
||||
}
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:4])
|
||||
}
|
||||
Vendored
+1
-1
@@ -232,7 +232,7 @@ func (m *Manager) Close() error {
|
||||
|
||||
var firstErr error
|
||||
|
||||
if err := m.tokenCache.Close(); err != nil && firstErr == nil {
|
||||
if err := m.tokenCache.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := m.metadataCache.Close(); err != nil && firstErr == nil {
|
||||
|
||||
@@ -842,10 +842,18 @@ func TestWorkerPool_TaskPanic(t *testing.T) {
|
||||
t.Error("Timeout waiting for tasks")
|
||||
}
|
||||
|
||||
// Pool should still be functional
|
||||
metrics := pool.GetMetrics()
|
||||
if metrics["tasksFailed"].(int64) < 1 {
|
||||
t.Error("Expected at least one failed task")
|
||||
// tasksFailed is incremented in the worker's deferred recover(), which runs
|
||||
// AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can
|
||||
// therefore return before the failure is recorded — reading the counter
|
||||
// immediately is a race that flakes on slow/contended CI runners. Poll until
|
||||
// the failure lands (or time out).
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for pool.GetMetrics()["tasksFailed"].(int64) < 1 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Error("Expected at least one failed task")
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -397,7 +397,7 @@ func (wp *WorkerPool) Submit(task func()) error {
|
||||
}
|
||||
|
||||
// worker is the main worker routine
|
||||
func (wp *WorkerPool) worker(id int) {
|
||||
func (wp *WorkerPool) worker(_ int) {
|
||||
defer wp.workerWg.Done()
|
||||
|
||||
for {
|
||||
|
||||
@@ -173,7 +173,7 @@ func (m *FeatureManager) LoadFromEnv() {
|
||||
for name, flag := range flags {
|
||||
envVar := "FEATURE_" + name
|
||||
if value := os.Getenv(envVar); value != "" {
|
||||
enabled := strings.ToLower(value) == "true" || value == "1"
|
||||
enabled := strings.EqualFold(value, "true") || value == "1"
|
||||
flag.enabled.Store(enabled)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
|
||||
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if strings.ToLower(scope) != ScopeOfflineAccess {
|
||||
if !strings.EqualFold(scope, ScopeOfflineAccess) {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +147,8 @@ func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
|
||||
return p
|
||||
}
|
||||
case ProviderTypeKeycloak:
|
||||
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/auth/realms/") {
|
||||
// Match both Keycloak <17 (`/auth/realms/`) and 17+ (`/realms/`).
|
||||
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/realms/") {
|
||||
return p
|
||||
}
|
||||
case ProviderTypeAWSCognito:
|
||||
|
||||
@@ -225,10 +225,15 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
|
||||
expected: oktaProvider,
|
||||
},
|
||||
{
|
||||
name: "Keycloak provider detection",
|
||||
name: "Keycloak provider detection (legacy /auth/realms/)",
|
||||
issuerURL: "https://auth.example.com/auth/realms/master",
|
||||
expected: keycloakProvider,
|
||||
},
|
||||
{
|
||||
name: "Keycloak provider detection (modern /realms/, KC 17+)",
|
||||
issuerURL: "https://auth.example.com/realms/master",
|
||||
expected: keycloakProvider,
|
||||
},
|
||||
{
|
||||
name: "AWS Cognito provider detection",
|
||||
issuerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example",
|
||||
|
||||
@@ -18,16 +18,17 @@ func GetProviderWarnings(providerType ProviderType) []ProviderWarning {
|
||||
|
||||
switch providerType {
|
||||
case ProviderTypeGitHub:
|
||||
warnings = append(warnings, ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "warning",
|
||||
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
|
||||
})
|
||||
warnings = append(warnings, ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "info",
|
||||
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
|
||||
})
|
||||
warnings = append(warnings,
|
||||
ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "warning",
|
||||
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
|
||||
},
|
||||
ProviderWarning{
|
||||
ProviderType: ProviderTypeGitHub,
|
||||
Level: "info",
|
||||
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
|
||||
})
|
||||
|
||||
case ProviderTypeAuth0:
|
||||
warnings = append(warnings, ProviderWarning{
|
||||
|
||||
@@ -116,7 +116,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
|
||||
// Continue to next attempt
|
||||
case <-ctx.Done():
|
||||
re.RecordFailure()
|
||||
return fmt.Errorf("retry cancelled: %w", ctx.Err())
|
||||
return fmt.Errorf("retry canceled: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,7 +301,7 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
allMetrics["summary"] = map[string]interface{}{
|
||||
summary := map[string]interface{}{
|
||||
"totalMechanisms": len(rm.mechanisms),
|
||||
"totalRequests": totalRequests,
|
||||
"totalSuccesses": totalSuccesses,
|
||||
@@ -310,8 +310,9 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
|
||||
|
||||
if totalRequests > 0 {
|
||||
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
|
||||
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
|
||||
summary["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
|
||||
}
|
||||
allMetrics["summary"] = summary
|
||||
|
||||
return allMetrics
|
||||
}
|
||||
|
||||
@@ -223,7 +223,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
if execErr == nil {
|
||||
t.Error("Expected error when context is cancelled")
|
||||
t.Error("Expected error when context is canceled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when context is already cancelled")
|
||||
t.Error("Expected error when context is already canceled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,7 +282,7 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
|
||||
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
|
||||
{name: "EOF", err: errors.New("EOF"), expected: false},
|
||||
{name: "random error", err: errors.New("something else"), expected: false},
|
||||
{name: "context cancelled", err: context.Canceled, expected: false},
|
||||
{name: "context canceled", err: context.Canceled, expected: false},
|
||||
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
|
||||
}
|
||||
|
||||
|
||||
+23
-1
@@ -155,12 +155,34 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string {
|
||||
// It checks X-Forwarded-Host header first (for proxy scenarios),
|
||||
// then falls back to req.Host.
|
||||
func DetermineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// sanitizeForwardedHost returns a single, well-formed host from a (possibly
|
||||
// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes
|
||||
// only the first value and rejects whitespace and control characters, so a
|
||||
// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison
|
||||
// the redirect URLs built from the result.
|
||||
func sanitizeForwardedHost(v string) string {
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexByte(v, ','); i >= 0 {
|
||||
v = v[:i]
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 {
|
||||
return ""
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// BuildFullURL constructs a URL from scheme, host, and path components.
|
||||
// It handles absolute URLs (returning them as-is) and ensures paths have leading slashes.
|
||||
func BuildFullURL(scheme, host, path string) string {
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
|
||||
// the fix for issue #132: token refresh path hardcoded the "email" claim and
|
||||
// ignored the configured userIdentifierClaim. Keycloak users without an email
|
||||
// claim (using sub or another identifier) were being kicked out on refresh
|
||||
// even though their initial login worked.
|
||||
//
|
||||
// The callback path (auth_flow.go) already honored userIdentifierClaim with
|
||||
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
|
||||
// after PR #100 (commit a316a98).
|
||||
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
expectedIdentifier string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "sub claim configured, only sub present (Keycloak no-email case)",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-keycloak-12345",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user-uuid-keycloak-12345",
|
||||
},
|
||||
{
|
||||
name: "preferred_username configured, claim present",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"preferred_username": "alice",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "alice",
|
||||
},
|
||||
{
|
||||
name: "configured claim missing, falls back to sub",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "fallback-sub-id",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
},
|
||||
{
|
||||
name: "email default, email present (backward compatibility)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "email default, no email and no sub - refresh fails",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: false,
|
||||
expectedIdentifier: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager(
|
||||
"test-encryption-key-32-bytes-long!!",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
NewLogger("error"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("session manager: %v", err)
|
||||
}
|
||||
defer sessionManager.Shutdown()
|
||||
|
||||
capturedClaims := tt.claims
|
||||
tOidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
userIdentifierClaim: tt.userIdentifierClaim,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &EnhancedMockTokenExchanger{
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token-jwt",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
},
|
||||
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
|
||||
extractClaimsFunc: func(token string) (map[string]any, error) {
|
||||
return capturedClaims, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
session.SetRefreshToken("initial-refresh-token")
|
||||
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
if refreshed != tt.expectSuccess {
|
||||
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
|
||||
}
|
||||
|
||||
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
|
||||
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// signGraphStyleAccessToken builds a JWT in Microsoft's Graph proprietary
|
||||
// nonce-header form: bytes that get signed contain the SHA256 hash of the
|
||||
// nonce, while the wire token ships the original nonce. A standard JWS
|
||||
// verifier always rejects these with `crypto/rsa: verification error`, which
|
||||
// is why Microsoft documents Graph access tokens as opaque to client apps:
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
// "you can't validate tokens for Microsoft Graph according to these rules
|
||||
// due to their proprietary format"
|
||||
func signGraphStyleAccessToken(t *testing.T, key *rsa.PrivateKey, kid, originalNonce string, claims map[string]any) string {
|
||||
t.Helper()
|
||||
|
||||
wireHeader := map[string]any{
|
||||
"alg": "RS256",
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
"nonce": originalNonce,
|
||||
}
|
||||
wireHeaderJSON, err := json.Marshal(wireHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
hashed := sha256.Sum256([]byte(originalNonce))
|
||||
signedHeader := map[string]any{
|
||||
"alg": "RS256",
|
||||
"kid": kid,
|
||||
"typ": "JWT",
|
||||
"nonce": fmt.Sprintf("%x", hashed),
|
||||
}
|
||||
signedHeaderJSON, err := json.Marshal(signedHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
wireHeaderB64 := base64.RawURLEncoding.EncodeToString(wireHeaderJSON)
|
||||
signedHeaderB64 := base64.RawURLEncoding.EncodeToString(signedHeaderJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
|
||||
signedInput := signedHeaderB64 + "." + claimsB64
|
||||
hSign := sha256.Sum256([]byte(signedInput))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hSign[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
return wireHeaderB64 + "." + claimsB64 + "." + base64.RawURLEncoding.EncodeToString(sig)
|
||||
}
|
||||
|
||||
// newAzureFollowupOIDC produces a TraefikOidc instance wired for an Azure
|
||||
// AD tenant with a captured error log buffer. Used by the issue #134 followup
|
||||
// tests to assert log behavior during validateAzureTokens flows.
|
||||
func newAzureFollowupOIDC(t *testing.T, jwks *JWKSet) (*TraefikOidc, *bytes.Buffer) {
|
||||
t.Helper()
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
jwkCache: &MockJWKCache{JWKS: jwks},
|
||||
tokenCache: tokenCache,
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
oidc.tokenVerifier = oidc
|
||||
oidc.jwtVerifier = oidc
|
||||
require.True(t, oidc.isAzureProvider(), "fixture must be detected as Azure provider")
|
||||
return oidc, errBuf
|
||||
}
|
||||
|
||||
// authedSessionWithTokens returns a SessionData populated with the supplied
|
||||
// access and ID tokens, marked authenticated and recently created. The
|
||||
// SessionManager carries a real ChunkManager so that GetAccessToken /
|
||||
// GetIDToken / GetRefreshToken behave like the production code path.
|
||||
func authedSessionWithTokens(t *testing.T, accessToken, idToken string) *SessionData {
|
||||
t.Helper()
|
||||
|
||||
chunkLogger := NewLogger("error")
|
||||
chunkManager := NewChunkManager(chunkLogger)
|
||||
t.Cleanup(chunkManager.Shutdown)
|
||||
|
||||
sd := CreateMockSessionData()
|
||||
sd.manager = &SessionManager{
|
||||
sessionMaxAge: 24 * time.Hour,
|
||||
chunkManager: chunkManager,
|
||||
logger: chunkLogger,
|
||||
}
|
||||
|
||||
sd.mainSession = sessions.NewSession(nil, "main")
|
||||
sd.mainSession.Values["authenticated"] = true
|
||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||
|
||||
sd.accessSession = sessions.NewSession(nil, "access")
|
||||
sd.accessSession.Values["token"] = accessToken
|
||||
sd.accessSession.Values["compressed"] = false
|
||||
|
||||
sd.idTokenSession = sessions.NewSession(nil, "id")
|
||||
sd.idTokenSession.Values["token"] = idToken
|
||||
sd.idTokenSession.Values["compressed"] = false
|
||||
|
||||
sd.refreshSession = sessions.NewSession(nil, "refresh")
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
sd.refreshSession.Values["compressed"] = false
|
||||
|
||||
return sd
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_GraphAccessTokenReproducesUsersError sanity-checks
|
||||
// that our crafted Graph-style token reproduces the exact rsa error string
|
||||
// quoted on the issue thread (dada-engineer 2026-05-08, friek 2026-05-11).
|
||||
//
|
||||
// Sanity test: must always pass, regardless of the issue #134 followup fix.
|
||||
// It exists so a future contributor does not accidentally weaken the
|
||||
// reproducer and assume the followup fix is no longer needed.
|
||||
func TestIssue134_Followup_GraphAccessTokenReproducesUsersError(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-followup-kid"
|
||||
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
parsedJWT, err := parseJWT(graphToken)
|
||||
require.NoError(t, err)
|
||||
pubKey := &rsaKey.PublicKey
|
||||
alg, _ := parsedJWT.Header["alg"].(string)
|
||||
verifyErr := verifySignatureWithKey(graphToken, pubKey, alg)
|
||||
require.Error(t, verifyErr)
|
||||
assert.Contains(t, verifyErr.Error(), "crypto/rsa: verification error",
|
||||
"reproducer must emit the exact error string reported on issue #134")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken is the
|
||||
// failing-then-passing test for the followup fix.
|
||||
//
|
||||
// Symptom (before fix): validateAzureTokens calls verifyToken on every
|
||||
// JWT-shaped access token. For Microsoft Graph access tokens (the default
|
||||
// when no custom resource is registered), verification always fails with
|
||||
// `crypto/rsa: verification error`, generating two error log lines per
|
||||
// request:
|
||||
//
|
||||
// UNKNOWN token verification failed: signature verification failed:
|
||||
// crypto/rsa: verification error
|
||||
// DIAGNOSTIC: Signature verification failed for kid=<kid>, alg=RS256:
|
||||
// crypto/rsa: verification error
|
||||
//
|
||||
// Microsoft's own documentation tells client apps not to validate Graph
|
||||
// access tokens. The fix matches that guidance: when an Azure access token
|
||||
// carries Microsoft's proprietary `nonce` JWT header, treat it as opaque
|
||||
// (skip JWT verification, fall through to ID token validation).
|
||||
func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-followup-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(time.Hour).Unix()
|
||||
|
||||
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-azure-graph", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": exp,
|
||||
"iat": now.Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"appid": "test-client-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"email": "user@example.com",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "id-token-jti-followup",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, idToken)
|
||||
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
|
||||
|
||||
output := errBuf.String()
|
||||
assert.NotContains(t, output, "crypto/rsa: verification error",
|
||||
"validateAzureTokens must not log rsa verification error for Graph-style access tokens; got: %q", output)
|
||||
assert.NotContains(t, output, "DIAGNOSTIC: Signature verification failed",
|
||||
"DIAGNOSTIC line must not fire for Graph-style access tokens; got: %q", output)
|
||||
assert.NotContains(t, output, "UNKNOWN token verification failed",
|
||||
"UNKNOWN classification log must not fire for Graph-style access tokens; got: %q", output)
|
||||
|
||||
assert.True(t, authenticated, "session must remain authenticated via the ID token fallback")
|
||||
assert.False(t, needsRefresh, "valid ID token must not signal a refresh need")
|
||||
assert.False(t, expired, "valid ID token must not be reported as expired")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection covers the
|
||||
// classifier added by the followup fix. Pure-function unit test for the
|
||||
// Microsoft proprietary marker we rely on (nonce in JWT header).
|
||||
func TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-detection-kid"
|
||||
standardToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
oidc, _ := newAzureFollowupOIDC(t, &JWKSet{})
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantUnverified bool
|
||||
}{
|
||||
{name: "standard JWT without nonce header", token: standardToken, wantUnverified: false},
|
||||
{name: "Microsoft proprietary token (nonce in header)", token: graphToken, wantUnverified: true},
|
||||
{name: "garbage token treated as unverifiable", token: "not-a-jwt-at-all", wantUnverified: true},
|
||||
{name: "empty token treated as unverifiable", token: "", wantUnverified: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := oidc.isUnverifiableAzureAccessToken(tc.token)
|
||||
assert.Equal(t, tc.wantUnverified, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_StandardAzureAccessTokenStillVerifies guards against
|
||||
// regression in the happy path: an access token issued for our own clientID
|
||||
// (custom Azure-registered API) — no proprietary nonce header, signed normally
|
||||
// — must still flow through the standard verification path and authenticate.
|
||||
func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-standard-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
now := time.Now()
|
||||
exp := now.Add(time.Hour).Unix()
|
||||
|
||||
// Custom-resource access token: aud points to the app, no nonce header.
|
||||
accessToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "api.read",
|
||||
"jti": "standard-access-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": now.Add(-2 * time.Minute).Unix(),
|
||||
"nbf": now.Add(-2 * time.Minute).Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"email": "user@example.com",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "standard-id-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, accessToken, idToken)
|
||||
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
|
||||
|
||||
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
|
||||
assert.False(t, needsRefresh)
|
||||
assert.False(t, expired)
|
||||
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error",
|
||||
"standard Azure token must not produce signature errors")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_GraphAccessTokenWithoutIDToken covers the edge where
|
||||
// the session has only a Graph access token (no ID token). The classifier must
|
||||
// preserve the existing "treat as opaque" semantics for backward compatibility:
|
||||
// authenticated=true even when there is no ID token to verify.
|
||||
func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-no-idt-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-no-idt", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "00000003-0000-0000-c000-000000000000",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "user-azure-id",
|
||||
"scp": "User.Read",
|
||||
})
|
||||
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, "")
|
||||
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
|
||||
|
||||
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
|
||||
assert.False(t, needsRefresh)
|
||||
assert.False(t, expired)
|
||||
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error")
|
||||
}
|
||||
|
||||
// TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification proves
|
||||
// the classifier is not a security regression. An attacker who forges a JWT
|
||||
// with a `nonce` JWT header (Microsoft's proprietary marker) but a payload
|
||||
// claiming `aud=our-clientID` should NOT gain authenticated status simply by
|
||||
// triggering the "treat as opaque" branch.
|
||||
//
|
||||
// This is the confused-deputy guardrail Microsoft warns about
|
||||
// (https://cwe.mitre.org/data/definitions/441.html): we treat the access token
|
||||
// as opaque, which means we DO NOT authorize from it — authorization comes
|
||||
// only from a separately verifiable ID token. An attacker without a valid ID
|
||||
// token must not be authenticated.
|
||||
func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *testing.T) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
attackerKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
const kid = "azure-attack-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
|
||||
}
|
||||
jwks := &JWKSet{Keys: []JWK{jwk}}
|
||||
|
||||
// Forged: attacker uses their OWN key, sets aud = our clientID, plants a
|
||||
// `nonce` header to trip the opaque-detection path.
|
||||
forgedAccessToken := signGraphStyleAccessToken(t, attackerKey, kid, "attacker-nonce", map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "attacker",
|
||||
"scp": "admin",
|
||||
})
|
||||
|
||||
// Forged ID token signed with the attacker's key — must fail verification
|
||||
// against the tenant JWKS.
|
||||
forgedIDToken, err := createTestJWT(attackerKey, "RS256", kid, map[string]any{
|
||||
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Minute).Unix(),
|
||||
"nbf": time.Now().Add(-2 * time.Minute).Unix(),
|
||||
"sub": "attacker",
|
||||
"email": "attacker@evil.example",
|
||||
"nonce": "id-token-oidc-nonce",
|
||||
"jti": "attacker-id-jti",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc, _ := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
|
||||
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, _, _ := oidc.validateAzureTokensRS(rs)
|
||||
assert.False(t, authenticated,
|
||||
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
|
||||
// verifies the fix for issue #134.
|
||||
//
|
||||
// Symptom (before fix): with a Redis backend wired into UniversalCache,
|
||||
// caching the parsed *parsedJWKS triggered:
|
||||
//
|
||||
// json: cannot unmarshal number 2251513...
|
||||
// into Go value of type float64
|
||||
//
|
||||
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
|
||||
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
|
||||
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
|
||||
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
|
||||
// interface{} parses numbers as float64, which cannot represent that range.
|
||||
// The user saw the error log on every request even though auth still worked
|
||||
// (fallback path rebuilt the keys in memory).
|
||||
//
|
||||
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
|
||||
// distributed backend never sees them.
|
||||
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
|
||||
var fetchCount int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&fetchCount, 1)
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
infoBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(infoBuf, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
ctx := context.Background()
|
||||
|
||||
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "first GetPublicKey should succeed")
|
||||
require.NotNil(t, pub1)
|
||||
gotRSA, ok := pub1.(*rsa.PublicKey)
|
||||
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
|
||||
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
|
||||
|
||||
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err, "second GetPublicKey should succeed")
|
||||
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
|
||||
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
|
||||
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
|
||||
|
||||
errOutput := errBuf.String()
|
||||
assert.NotContains(t, errOutput, "Failed to deserialize",
|
||||
"deserialize error must not appear with the fix in place; got: %s", errOutput)
|
||||
assert.NotContains(t, errOutput, "into Go value of type float64",
|
||||
"float64 unmarshal error must not appear; got: %s", errOutput)
|
||||
|
||||
parsedKey := server.URL + parsedKeysSuffix
|
||||
jwksKey := server.URL
|
||||
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
|
||||
fullKey := redisCfg.RedisPrefix + k
|
||||
assert.False(t, mr.Exists(fullKey),
|
||||
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
|
||||
// data left in Redis under a JWK :parsed key from a prior buggy version is
|
||||
// ignored: the local-only fix never reads that key, so no log spam, and the
|
||||
// fallback path returns a real *rsa.PublicKey.
|
||||
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "issue134stale:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
const kid = "azure-test-kid"
|
||||
jwk := JWK{
|
||||
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pre-poison Redis with the kind of payload the old buggy path would have
|
||||
// produced (huge unquoted JSON number for the modulus). With the fix the
|
||||
// JWKCache must not even read this key.
|
||||
poisoned := []byte("\x01" + strings.Replace(
|
||||
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
|
||||
"NUMBER", rsaKey.N.String(), 1,
|
||||
))
|
||||
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
|
||||
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
|
||||
|
||||
errBuf := &bytes.Buffer{}
|
||||
logger := &Logger{
|
||||
logError: log.New(errBuf, "", 0),
|
||||
logInfo: log.New(io.Discard, "", 0),
|
||||
logDebug: log.New(io.Discard, "", 0),
|
||||
}
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
jwkCache := &JWKCache{cache: cache}
|
||||
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pub)
|
||||
gotRSA, ok := pub.(*rsa.PublicKey)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
|
||||
|
||||
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
|
||||
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
|
||||
}
|
||||
|
||||
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
|
||||
// pair never reads or writes the configured backend.
|
||||
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisCfg := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisCfg.RedisPrefix = "local:"
|
||||
backend, err := backends.NewRedisBackend(redisCfg)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 10,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
type unsafeShape struct {
|
||||
hidden map[string]interface{}
|
||||
}
|
||||
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
|
||||
|
||||
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
|
||||
|
||||
got, found := cache.GetLocal("local-key")
|
||||
require.True(t, found)
|
||||
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
|
||||
|
||||
for _, k := range mr.Keys() {
|
||||
assert.NotContains(t, k, "local-key",
|
||||
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
|
||||
}
|
||||
|
||||
cache.mu.Lock()
|
||||
delete(cache.items, "local-key")
|
||||
cache.lruList.Init()
|
||||
cache.currentSize = 0
|
||||
cache.currentMemory = 0
|
||||
cache.mu.Unlock()
|
||||
|
||||
_, found = cache.GetLocal("local-key")
|
||||
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
|
||||
}
|
||||
|
||||
// big2bytes returns the big-endian byte slice for a positive int.
|
||||
func big2bytes(e int) []byte {
|
||||
if e <= 0 {
|
||||
return []byte{}
|
||||
}
|
||||
var buf []byte
|
||||
for e > 0 {
|
||||
buf = append([]byte{byte(e & 0xff)}, buf...)
|
||||
e >>= 8
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// samePublicKey reports whether two crypto.PublicKey instances represent the
|
||||
// same RSA key, used to confirm cache hits return identical reconstructed
|
||||
// keys.
|
||||
func samePublicKey(a, b interface{}) bool {
|
||||
ar, ok1 := a.(*rsa.PublicKey)
|
||||
br, ok2 := b.(*rsa.PublicKey)
|
||||
if !ok1 || !ok2 {
|
||||
return false
|
||||
}
|
||||
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
|
||||
}
|
||||
@@ -0,0 +1,925 @@
|
||||
package traefikoidc
|
||||
|
||||
// issue135_regression_test.go — regression tests for RFC 7523 private_key_jwt
|
||||
// client authentication (issue #135).
|
||||
//
|
||||
// These tests guard:
|
||||
// - Correct JWT construction and cryptographic signature for all supported
|
||||
// algorithms (RS*/PS*/ES*).
|
||||
// - Proper validation of alg/key type combinations and empty-kid rejection.
|
||||
// - JTI uniqueness across concurrent calls.
|
||||
// - PEM variant tolerance (PKCS#8, PKCS#1, SEC1).
|
||||
// - Config.Validate() behavior for all private_key_jwt configuration paths.
|
||||
// - buildClientAssertionSignerFromConfig: inline PEM, file-backed PEM, default alg.
|
||||
// - Wire-up in exchangeTokens: assertion fields sent, client_secret absent.
|
||||
// - Wire-up in RevokeTokenWithProvider: assertion fields sent, audience = tokenURL.
|
||||
// - Back-compat: client_secret_post path unchanged when clientAssertion == nil.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ── A. Signer unit tests ──────────────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_SignerRSAFamily verifies that NewClientAssertionSigner + Sign
|
||||
// produces a well-formed, cryptographically valid JWT for every RSA-family
|
||||
// algorithm (RS256/RS384/RS512/PS256/PS384/PS512).
|
||||
func TestIssue135_SignerRSAFamily(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
cases := []struct {
|
||||
alg string
|
||||
hashFn func([]byte) []byte
|
||||
isPS bool
|
||||
hash crypto.Hash
|
||||
}{
|
||||
{"RS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, false, crypto.SHA256},
|
||||
{"RS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, false, crypto.SHA384},
|
||||
{"RS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, false, crypto.SHA512},
|
||||
{"PS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, true, crypto.SHA256},
|
||||
{"PS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, true, crypto.SHA384},
|
||||
{"PS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, true, crypto.SHA512},
|
||||
}
|
||||
|
||||
const (
|
||||
audience = "https://example.com/token"
|
||||
clientID = "client-abc"
|
||||
kid = "kid-1"
|
||||
)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.alg, func(t *testing.T) {
|
||||
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign(audience, clientID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3, "JWT must have three dot-separated parts")
|
||||
|
||||
// Decode and check header.
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, tc.alg, hdr["alg"])
|
||||
assert.Equal(t, "JWT", hdr["typ"])
|
||||
assert.Equal(t, kid, hdr["kid"])
|
||||
|
||||
// Decode and check claims.
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, clientID, clms["iss"])
|
||||
assert.Equal(t, clientID, clms["sub"])
|
||||
assert.Equal(t, audience, clms["aud"])
|
||||
|
||||
iat, ok := clms["iat"].(float64)
|
||||
require.True(t, ok, "iat must be numeric")
|
||||
exp, ok := clms["exp"].(float64)
|
||||
require.True(t, ok, "exp must be numeric")
|
||||
assert.InDelta(t, 60, exp-iat, 2, "exp-iat must equal ~60s")
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
assert.True(t, iat <= now+2 && iat >= now-5, "iat must be current time ±5s")
|
||||
|
||||
jti, ok := clms["jti"].(string)
|
||||
require.True(t, ok, "jti must be a string")
|
||||
assert.Len(t, jti, 32, "jti must be 32-char hex (16 bytes → hex)")
|
||||
|
||||
// Verify cryptographic signature.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := tc.hashFn([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
pub := &rsaKey.PublicKey
|
||||
if tc.isPS {
|
||||
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: tc.hash}
|
||||
assert.NoError(t, rsa.VerifyPSS(pub, tc.hash, digest, sigBytes, opts),
|
||||
"PSS signature verification failed for %s", tc.alg)
|
||||
} else {
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(pub, tc.hash, digest, sigBytes),
|
||||
"PKCS1v15 signature verification failed for %s", tc.alg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerECDSAFamily verifies correct JWT production for all
|
||||
// ECDSA algorithms (ES256/ES384/ES512) including that the signature is the
|
||||
// raw r||s encoding (not ASN.1 DER) and is verifiable with the matching key.
|
||||
func TestIssue135_SignerECDSAFamily(t *testing.T) {
|
||||
cases := []struct {
|
||||
alg string
|
||||
curve elliptic.Curve
|
||||
hashFn func([]byte) []byte
|
||||
hash crypto.Hash
|
||||
}{
|
||||
{"ES256", elliptic.P256(), func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, crypto.SHA256},
|
||||
{"ES384", elliptic.P384(), func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, crypto.SHA384},
|
||||
{"ES512", elliptic.P521(), func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, crypto.SHA512},
|
||||
}
|
||||
|
||||
const (
|
||||
audience = "https://idp.example.com/token"
|
||||
clientID = "ec-client"
|
||||
kid = "ec-kid"
|
||||
)
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.alg, func(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign(audience, clientID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
byteLen := (tc.curve.Params().BitSize + 7) / 8
|
||||
assert.Len(t, sigBytes, 2*byteLen,
|
||||
"ECDSA signature must be raw r||s (2×%d bytes for %s)", byteLen, tc.alg)
|
||||
|
||||
r := new(big.Int).SetBytes(sigBytes[:byteLen])
|
||||
s := new(big.Int).SetBytes(sigBytes[byteLen:])
|
||||
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := tc.hashFn([]byte(sigInput))
|
||||
|
||||
ok := ecdsa.Verify(&ecKey.PublicKey, digest, r, s)
|
||||
assert.True(t, ok, "ECDSA signature verification failed for %s", tc.alg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerRejectsAlgKeyMismatch verifies that the signer constructor
|
||||
// rejects type mismatches between key type and algorithm, unknown algorithms,
|
||||
// and an empty kid.
|
||||
func TestIssue135_SignerRejectsAlgKeyMismatch(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
rsaPEM := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
ecPEM := encodeECPKCS8(t, ecKey)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
pemBytes []byte
|
||||
alg string
|
||||
kid string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "RSA key with ES256",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "ES256",
|
||||
kid: "k1",
|
||||
wantErr: "EC key",
|
||||
},
|
||||
{
|
||||
name: "EC key with RS256",
|
||||
pemBytes: ecPEM,
|
||||
alg: "RS256",
|
||||
kid: "k1",
|
||||
wantErr: "RSA key",
|
||||
},
|
||||
{
|
||||
name: "unknown alg HS256",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "HS256",
|
||||
kid: "k1",
|
||||
wantErr: "unsupported",
|
||||
},
|
||||
{
|
||||
name: "empty kid",
|
||||
pemBytes: rsaPEM,
|
||||
alg: "RS256",
|
||||
kid: "",
|
||||
wantErr: "kid must not be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewClientAssertionSigner(tc.pemBytes, tc.alg, tc.kid)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.wantErr),
|
||||
"error should mention %q", tc.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerJTIUniqueness signs 50 assertions with the same signer
|
||||
// and asserts all jti values are distinct. Guards against broken entropy reuse.
|
||||
func TestIssue135_SignerJTIUniqueness(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "jti-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
seen := make(map[string]bool, 50)
|
||||
for i := range 50 {
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "client-x")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
jti, ok := clms["jti"].(string)
|
||||
require.True(t, ok)
|
||||
assert.False(t, seen[jti], "jti %q was reused at iteration %d", jti, i)
|
||||
seen[jti] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_SignerPEMVariants confirms that all PEM block types understood
|
||||
// by NewClientAssertionSigner are parsed correctly: PKCS#8 ("PRIVATE KEY"),
|
||||
// PKCS#1 ("RSA PRIVATE KEY"), and SEC1 ("EC PRIVATE KEY").
|
||||
func TestIssue135_SignerPEMVariants(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("RSA PKCS8", func(t *testing.T) {
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
|
||||
require.NoError(t, err)
|
||||
assertValidRSAJWT(t, rsaKey, signer, "RS256")
|
||||
})
|
||||
|
||||
t.Run("RSA PKCS1", func(t *testing.T) {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
|
||||
require.NoError(t, err)
|
||||
assertValidRSAJWT(t, rsaKey, signer, "RS256")
|
||||
})
|
||||
|
||||
t.Run("EC PKCS8", func(t *testing.T) {
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
|
||||
require.NoError(t, err)
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "cid")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
})
|
||||
|
||||
t.Run("EC SEC1", func(t *testing.T) {
|
||||
der, err := x509.MarshalECPrivateKey(ecKey)
|
||||
require.NoError(t, err)
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
|
||||
require.NoError(t, err)
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "cid")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
})
|
||||
}
|
||||
|
||||
// ── B. Config validation ──────────────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_ConfigValidation table-drives Config.Validate() for every
|
||||
// client-authentication-related validation branch.
|
||||
func TestIssue135_ConfigValidation(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
validPEM := string(encodeRSAPKCS8(t, rsaKey))
|
||||
|
||||
// baseConfig returns the minimum valid config, modified per test case.
|
||||
base := func() *Config {
|
||||
return &Config{
|
||||
ProviderURL: "https://idp.example.com",
|
||||
CallbackURL: "/cb",
|
||||
ClientID: "cid",
|
||||
ClientSecret: "secret",
|
||||
SessionEncryptionKey: "01234567890123456789012345678901", // 32 chars
|
||||
RateLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
wantErr string // empty = expect nil error
|
||||
}{
|
||||
{
|
||||
name: "default empty method + secret ok",
|
||||
mutate: func(c *Config) { /* nothing extra */ },
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "explicit client_secret_post + secret ok",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "client_secret_post"
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt inline key + kid ok",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt no key at all",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "clientAssertionPrivateKey",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt both inline and path",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyPath = "/tmp/key.pem"
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
},
|
||||
wantErr: "only one of",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt key but no kid",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
},
|
||||
wantErr: "clientAssertionKeyID",
|
||||
},
|
||||
{
|
||||
name: "private_key_jwt unsupported alg HS256",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "private_key_jwt"
|
||||
c.ClientSecret = ""
|
||||
c.ClientAssertionPrivateKey = validPEM
|
||||
c.ClientAssertionKeyID = "k1"
|
||||
c.ClientAssertionAlg = "HS256"
|
||||
},
|
||||
wantErr: "is not supported",
|
||||
},
|
||||
{
|
||||
name: "unknown client auth method",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "weird"
|
||||
},
|
||||
wantErr: "is not supported",
|
||||
},
|
||||
{
|
||||
name: "client_secret_post with no secret",
|
||||
mutate: func(c *Config) {
|
||||
c.ClientAuthMethod = "client_secret_post"
|
||||
c.ClientSecret = ""
|
||||
},
|
||||
wantErr: "clientSecret is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg := base()
|
||||
tc.mutate(cfg)
|
||||
err := cfg.Validate()
|
||||
if tc.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.wantErr,
|
||||
"error must mention %q", tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue135_ConfigKeyPathLoadsFile verifies that buildClientAssertionSignerFromConfig
|
||||
// reads the PEM key from disk when ClientAssertionKeyPath is set.
|
||||
func TestIssue135_ConfigKeyPathLoadsFile(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
dir := t.TempDir()
|
||||
keyFile := dir + "/private.pem"
|
||||
require.NoError(t, os.WriteFile(keyFile, pemBytes, 0o600))
|
||||
|
||||
cfg := &Config{
|
||||
ClientAuthMethod: "private_key_jwt",
|
||||
ClientAssertionKeyPath: keyFile,
|
||||
ClientAssertionKeyID: "file-kid",
|
||||
ClientAssertionAlg: "RS256",
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err, "should load signer from key file")
|
||||
require.NotNil(t, signer)
|
||||
|
||||
// Confirm signer produces a valid JWT.
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "client-from-file")
|
||||
require.NoError(t, err)
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3, "should produce a 3-part JWT")
|
||||
}
|
||||
|
||||
// ── C. Wire-up — exchangeTokens ───────────────────────────────────────────────
|
||||
|
||||
// TestIssue135_AuthCodeExchangeUsesAssertion confirms that exchangeTokens sends
|
||||
// client_assertion + client_assertion_type instead of client_secret when a
|
||||
// ClientAssertionSigner is configured, and that the assertion JWT is valid.
|
||||
func TestIssue135_AuthCodeExchangeUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
var capturedBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body := make([]byte, r.ContentLength)
|
||||
_, _ = r.Body.Read(body)
|
||||
capturedBody = body
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Return a minimal token response so exchangeTokens doesn't error.
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at",
|
||||
IDToken: "it",
|
||||
RefreshToken: "rt",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "wire-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "wire-client",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: signer,
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err = oidc.exchangeTokens(context.Background(), "authorization_code", "code-x", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
form, err := url.ParseQuery(string(capturedBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
form.Get("client_assertion_type"), "client_assertion_type must be set")
|
||||
assertionJWT := form.Get("client_assertion")
|
||||
assert.NotEmpty(t, assertionJWT, "client_assertion must be present")
|
||||
assert.Empty(t, form.Get("client_secret"), "client_secret must not be sent when using assertion")
|
||||
assert.Equal(t, "wire-client", form.Get("client_id"))
|
||||
assert.Equal(t, "code-x", form.Get("code"))
|
||||
assert.Equal(t, "authorization_code", form.Get("grant_type"))
|
||||
|
||||
// Verify assertion JWT: header, claims, signature.
|
||||
parts := strings.Split(assertionJWT, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "RS256", hdr["alg"])
|
||||
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, "wire-client", clms["iss"])
|
||||
assert.Equal(t, "wire-client", clms["sub"])
|
||||
assert.Equal(t, server.URL, clms["aud"],
|
||||
"audience must be the tokenURL (RFC 7523 §3)")
|
||||
|
||||
// Verify signature with RSA public key.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
// TestIssue135_RefreshTokenUsesAssertion verifies that the refresh_token grant
|
||||
// type also sends client_assertion and the correct form fields.
|
||||
func TestIssue135_RefreshTokenUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "new-at",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rt-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "rt-client",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: signer,
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err = oidc.exchangeTokens(context.Background(), "refresh_token", "rt-y", "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "refresh_token", capturedForm.Get("grant_type"))
|
||||
assert.Equal(t, "rt-y", capturedForm.Get("refresh_token"))
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
capturedForm.Get("client_assertion_type"))
|
||||
assert.NotEmpty(t, capturedForm.Get("client_assertion"))
|
||||
assert.Empty(t, capturedForm.Get("client_secret"))
|
||||
}
|
||||
|
||||
// TestIssue135_BackcompatClientSecretPath confirms that exchangeTokens sends
|
||||
// client_secret and does NOT send client_assertion when clientAssertion is nil.
|
||||
func TestIssue135_BackcompatClientSecretPath(t *testing.T) {
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "legacy-client",
|
||||
clientSecret: "legacy-secret",
|
||||
tokenHTTPClient: server.Client(),
|
||||
clientAssertion: nil, // back-compat path
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bc", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "legacy-secret", capturedForm.Get("client_secret"),
|
||||
"client_secret must be sent on the classic path")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion"),
|
||||
"client_assertion must NOT be present on the classic path")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion_type"),
|
||||
"client_assertion_type must NOT be present on the classic path")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicAuth verifies that when clientAuthMethod is
|
||||
// "client_secret_basic", exchangeTokens sends an HTTP Basic Authorization
|
||||
// header carrying url-encoded client_id:client_secret per RFC 6749 §2.3.1,
|
||||
// and that neither client_id nor client_secret appears in the form body.
|
||||
func TestIssue135_ClientSecretBasicAuth(t *testing.T) {
|
||||
var capturedAuth string
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: "at-basic", TokenType: "Bearer", ExpiresIn: 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "basic-client",
|
||||
clientSecret: "basic-secret",
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
tokenHTTPClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bb", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, strings.HasPrefix(capturedAuth, "Basic "),
|
||||
"Authorization header must start with 'Basic ', got %q", capturedAuth)
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err, "Authorization payload must be valid base64")
|
||||
user, pass, ok := strings.Cut(string(raw), ":")
|
||||
require.True(t, ok, "Authorization payload must contain a single ':' separator")
|
||||
assert.Equal(t, "basic-client", user, "client_id should round-trip through QueryEscape")
|
||||
assert.Equal(t, "basic-secret", pass, "client_secret should round-trip through QueryEscape")
|
||||
|
||||
assert.Empty(t, capturedForm.Get("client_id"),
|
||||
"client_id must NOT be in the body when using client_secret_basic")
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must NOT be in the body when using client_secret_basic")
|
||||
assert.Empty(t, capturedForm.Get("client_assertion"),
|
||||
"client_assertion must NOT be present on the basic-auth path")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicURLEncodesReservedChars verifies that
|
||||
// credentials containing reserved characters (`:`, `+`, `/`, etc.) are
|
||||
// form-urlencoded before base64 per RFC 6749 §2.3.1, so the receiving
|
||||
// authorization server can decode them deterministically.
|
||||
func TestIssue135_ClientSecretBasicURLEncodesReservedChars(t *testing.T) {
|
||||
var capturedAuth string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(TokenResponse{AccessToken: "at", TokenType: "Bearer", ExpiresIn: 3600})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
const (
|
||||
clientID = "weird:id+1"
|
||||
clientSecret = "p@ss/word=&" //nolint:gosec // test fixture
|
||||
)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
tokenHTTPClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = server.URL
|
||||
|
||||
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "c", "https://app/cb", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err)
|
||||
|
||||
wantUser := url.QueryEscape(clientID)
|
||||
wantPass := url.QueryEscape(clientSecret)
|
||||
assert.Equal(t, wantUser+":"+wantPass, string(raw),
|
||||
"both halves must be form-urlencoded before the base64 step")
|
||||
}
|
||||
|
||||
// TestIssue135_ClientSecretBasicRevocation verifies that the revocation path
|
||||
// honors client_secret_basic identically to the token path.
|
||||
func TestIssue135_ClientSecretBasicRevocation(t *testing.T) {
|
||||
var capturedAuth string
|
||||
var capturedForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: "rev-basic",
|
||||
clientSecret: "rev-secret",
|
||||
clientAuthMethod: "client_secret_basic",
|
||||
httpClient: server.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
oidc.tokenURL = "https://idp.example.com/token"
|
||||
oidc.revocationURL = server.URL
|
||||
|
||||
require.NoError(t, oidc.RevokeTokenWithProvider("opaque-tok", "access_token"))
|
||||
|
||||
require.True(t, strings.HasPrefix(capturedAuth, "Basic "), "got %q", capturedAuth)
|
||||
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rev-basic:rev-secret", string(raw))
|
||||
|
||||
assert.Equal(t, "opaque-tok", capturedForm.Get("token"))
|
||||
assert.Equal(t, "access_token", capturedForm.Get("token_type_hint"))
|
||||
assert.Empty(t, capturedForm.Get("client_id"),
|
||||
"client_id must NOT be in body on Basic-auth revocation")
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must NOT be in body on Basic-auth revocation")
|
||||
}
|
||||
|
||||
// ── D. Wire-up — RevokeTokenWithProvider ────────────────────────────────────
|
||||
|
||||
// TestIssue135_RevocationUsesAssertion verifies that RevokeTokenWithProvider
|
||||
// sends client_assertion (not client_secret), and that the assertion's audience
|
||||
// is the tokenURL, not the revocationURL (per RFC 7523 §3).
|
||||
func TestIssue135_RevocationUsesAssertion(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
const (
|
||||
tokenEndpoint = "https://idp.example.com/token" // audience for assertion
|
||||
clientIDVal = "revoke-client"
|
||||
)
|
||||
|
||||
var capturedForm url.Values
|
||||
// Revocation endpoint — deliberate separate URL to confirm audience != revocationURL.
|
||||
revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.NoError(t, r.ParseForm())
|
||||
capturedForm = r.Form
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer revokeServer.Close()
|
||||
|
||||
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rev-kid")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
clientID: clientIDVal,
|
||||
clientAssertion: signer,
|
||||
httpClient: revokeServer.Client(),
|
||||
logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
// tokenURL drives assertion audience; revocationURL is where the POST goes.
|
||||
oidc.tokenURL = tokenEndpoint
|
||||
oidc.revocationURL = revokeServer.URL
|
||||
|
||||
err = oidc.RevokeTokenWithProvider("some-token", "refresh_token")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
capturedForm.Get("client_assertion_type"))
|
||||
assertionJWT := capturedForm.Get("client_assertion")
|
||||
assert.NotEmpty(t, assertionJWT)
|
||||
assert.Empty(t, capturedForm.Get("client_secret"),
|
||||
"client_secret must not appear in revocation request with assertion")
|
||||
|
||||
// Verify the assertion audience is tokenURL (not revocationURL).
|
||||
parts := strings.Split(assertionJWT, ".")
|
||||
require.Len(t, parts, 3)
|
||||
clms := decodeJSONPart(t, parts[1])
|
||||
assert.Equal(t, tokenEndpoint, clms["aud"],
|
||||
"assertion audience must be tokenURL, not revocationURL")
|
||||
|
||||
// Sanity-check cryptographic validity.
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
// ── E. End-to-end via buildClientAssertionSignerFromConfig ───────────────────
|
||||
|
||||
// TestIssue135_BuildSignerFromInlineConfig confirms that the full config→signer
|
||||
// pipeline works for an ES256 key specified inline in the Config struct.
|
||||
func TestIssue135_BuildSignerFromInlineConfig(t *testing.T) {
|
||||
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pemBytes := encodeECPKCS8(t, ecKey)
|
||||
|
||||
cfg := &Config{
|
||||
ClientAuthMethod: "private_key_jwt",
|
||||
ClientAssertionPrivateKey: string(pemBytes),
|
||||
ClientAssertionKeyID: "inline-ec-kid",
|
||||
ClientAssertionAlg: "ES256",
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "inline-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "ES256", hdr["alg"])
|
||||
assert.Equal(t, "inline-ec-kid", hdr["kid"])
|
||||
|
||||
// Verify the EC signature.
|
||||
byteLen := (elliptic.P256().Params().BitSize + 7) / 8
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sigBytes, 2*byteLen)
|
||||
|
||||
r := new(big.Int).SetBytes(sigBytes[:byteLen])
|
||||
s := new(big.Int).SetBytes(sigBytes[byteLen:])
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
assert.True(t, ecdsa.Verify(&ecKey.PublicKey, digest, r, s))
|
||||
}
|
||||
|
||||
// TestIssue135_BuildSignerDefaultsToRS256 verifies that an empty
|
||||
// ClientAssertionAlg defaults to RS256.
|
||||
func TestIssue135_BuildSignerDefaultsToRS256(t *testing.T) {
|
||||
rsaKey := genRSAKey(t, 2048)
|
||||
pemBytes := encodeRSAPKCS8(t, rsaKey)
|
||||
|
||||
cfg := &Config{
|
||||
ClientAssertionPrivateKey: string(pemBytes),
|
||||
ClientAssertionKeyID: "default-alg-kid",
|
||||
ClientAssertionAlg: "", // intentionally empty
|
||||
}
|
||||
|
||||
signer, err := buildClientAssertionSignerFromConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "default-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, "RS256", hdr["alg"], "empty alg must default to RS256")
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
// genRSAKey generates an RSA key of the given bit size, failing the test on error.
|
||||
func genRSAKey(t *testing.T, bits int) *rsa.PrivateKey {
|
||||
t.Helper()
|
||||
k, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
require.NoError(t, err)
|
||||
return k
|
||||
}
|
||||
|
||||
// encodeRSAPKCS8 marshals an RSA key as PKCS#8 PEM ("PRIVATE KEY").
|
||||
func encodeRSAPKCS8(t *testing.T, key *rsa.PrivateKey) []byte {
|
||||
t.Helper()
|
||||
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
|
||||
}
|
||||
|
||||
// encodeECPKCS8 marshals an EC key as PKCS#8 PEM ("PRIVATE KEY").
|
||||
func encodeECPKCS8(t *testing.T, key *ecdsa.PrivateKey) []byte {
|
||||
t.Helper()
|
||||
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
|
||||
}
|
||||
|
||||
// decodeJSONPart base64url-decodes a JWT part and parses it as a JSON object.
|
||||
func decodeJSONPart(t *testing.T, b64url string) map[string]any {
|
||||
t.Helper()
|
||||
raw, err := base64.RawURLEncoding.DecodeString(b64url)
|
||||
require.NoError(t, err, "base64url decode of JWT part failed")
|
||||
var m map[string]any
|
||||
require.NoError(t, json.Unmarshal(raw, &m), "JSON unmarshal of JWT part failed")
|
||||
return m
|
||||
}
|
||||
|
||||
// sha256SumBytes returns the SHA-256 digest of b as a byte slice.
|
||||
func sha256SumBytes(b []byte) []byte {
|
||||
h := sha256.Sum256(b)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// assertValidRSAJWT signs a JWT with signer and verifies the RS256 signature
|
||||
// against the given RSA public key. Used by PEM variant tests.
|
||||
func assertValidRSAJWT(t *testing.T, key *rsa.PrivateKey, signer *ClientAssertionSigner, alg string) {
|
||||
t.Helper()
|
||||
jwtStr, err := signer.Sign("https://example.com/token", "pem-client")
|
||||
require.NoError(t, err)
|
||||
|
||||
parts := strings.Split(jwtStr, ".")
|
||||
require.Len(t, parts, 3)
|
||||
|
||||
hdr := decodeJSONPart(t, parts[0])
|
||||
assert.Equal(t, alg, hdr["alg"])
|
||||
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
require.NoError(t, err)
|
||||
|
||||
sigInput := parts[0] + "." + parts[1]
|
||||
digest := sha256SumBytes([]byte(sigInput))
|
||||
assert.NoError(t, rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, digest, sigBytes))
|
||||
}
|
||||
|
||||
@@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
|
||||
// Test 3: Rate limiting
|
||||
t.Run("RateLimiting", func(t *testing.T) {
|
||||
// Reset circuit breaker to closed state for this test
|
||||
coordinator.circuitBreaker.mutex.Lock()
|
||||
// Reset circuit breaker to closed state for this test. All fields are
|
||||
// atomic so we don't need any mutex.
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
|
||||
coordinator.circuitBreaker.mutex.Unlock()
|
||||
|
||||
// Temporarily increase circuit breaker threshold to not interfere
|
||||
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
|
||||
@@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
time.Sleep(config.CleanupInterval * 3)
|
||||
|
||||
// Old sessions should be cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
count := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
count := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have fewer sessions after cleanup
|
||||
if count > 10 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
@@ -18,6 +19,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// parsedKeysSuffix marks the parallel UniversalCache entry that stores
|
||||
// pre-parsed public keys for a given JWKS URL.
|
||||
const parsedKeysSuffix = ":parsed"
|
||||
|
||||
// parsedJWKS holds keys decoded from a JWKSet, indexed by kid. Storing the
|
||||
// already-parsed crypto.PublicKey avoids re-running the DER/PEM round trip
|
||||
// on every JWT verification — a costly operation under the yaegi interpreter
|
||||
// that hosts Traefik plugins.
|
||||
type parsedJWKS struct {
|
||||
keys map[string]crypto.PublicKey
|
||||
}
|
||||
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It can represent different key types including RSA, EC, and symmetric keys.
|
||||
type JWK struct {
|
||||
@@ -40,15 +53,32 @@ type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache provides thread-safe caching of JWKS using UniversalCache
|
||||
// JWKCache provides thread-safe caching of JWKS using UniversalCache.
|
||||
//
|
||||
// inflightFetches deduplicates concurrent fetches for the same JWKS URL.
|
||||
// It replaces a global sync.RWMutex that was previously held for the entire
|
||||
// HTTP round-trip in GetJWKS: on a cold cache (cold pod, JWK rotation, brief
|
||||
// network blip) every concurrent request piled up on that single Lock(), and
|
||||
// under Yaegi each Lock acquisition costs 10-50ms of interpreter-dispatch
|
||||
// overhead. The singleflight pattern keeps the cold-cache cost O(1) HTTP
|
||||
// fetch regardless of how many requests are waiting.
|
||||
type JWKCache struct {
|
||||
cache *UniversalCache
|
||||
mutex sync.RWMutex
|
||||
cache *UniversalCache
|
||||
inflightFetches sync.Map // map[jwksURL string]*jwksFetch
|
||||
}
|
||||
|
||||
// jwksFetch represents an in-flight JWKS fetch. Done is closed when the fetch
|
||||
// completes; jwks and err carry the result (one of them is set, never both).
|
||||
type jwksFetch struct {
|
||||
done chan struct{}
|
||||
jwks *JWKSet
|
||||
err error
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the contract for JWK caching implementations.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error)
|
||||
Cleanup()
|
||||
Close()
|
||||
}
|
||||
@@ -62,38 +92,146 @@ func NewJWKCache() *JWKCache {
|
||||
}
|
||||
|
||||
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
|
||||
//
|
||||
// The entry is stored locally only via SetLocal/GetLocal. Going through a
|
||||
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
|
||||
// into map[string]interface{}, the type assertion below fails, and every
|
||||
// request refetches from the upstream. JWK rotation is rare and a per-replica
|
||||
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Check cache first
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
// Fast path: cache hit.
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Singleflight: dedupe concurrent fetches per URL key. The first arrival
|
||||
// performs the HTTP fetch; any later arrival for the same URL waits on
|
||||
// its done channel and shares the result. No global lock is held during
|
||||
// the fetch.
|
||||
candidate := &jwksFetch{done: make(chan struct{})}
|
||||
if existing, loaded := c.inflightFetches.LoadOrStore(jwksURL, candidate); loaded {
|
||||
f, _ := existing.(*jwksFetch)
|
||||
select {
|
||||
case <-f.done:
|
||||
return f.jwks, f.err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if cachedValue, found := c.cache.Get(jwksURL); found {
|
||||
// We're the leader. Make absolutely sure the result fields and the
|
||||
// in-flight map entry are cleaned up before any waiter unblocks.
|
||||
defer func() {
|
||||
c.inflightFetches.Delete(jwksURL)
|
||||
close(candidate.done)
|
||||
}()
|
||||
|
||||
// Re-check the cache in case a concurrent fetch completed between our
|
||||
// initial miss and our LoadOrStore win.
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
candidate.jwks = jwks
|
||||
return jwks, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch from URL
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
candidate.err = err
|
||||
return nil, err
|
||||
}
|
||||
if len(jwks.Keys) == 0 {
|
||||
candidate.err = fmt.Errorf("JWKS response contains no keys")
|
||||
return nil, candidate.err
|
||||
}
|
||||
|
||||
// Cache for 1 hour.
|
||||
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
candidate.jwks = jwks
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the parsed public key for a given kid, fetching and
|
||||
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
|
||||
// stored alongside the raw JWKSet under a sibling cache key with the same
|
||||
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
|
||||
//
|
||||
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
|
||||
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
|
||||
// which contain *big.Int that marshals to a hundreds-digit JSON number.
|
||||
// On a distributed backend round-trip, json.Unmarshal into interface{} would
|
||||
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
|
||||
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
|
||||
// Marshal, leaking the modulus into the cached payload (issue #134).
|
||||
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
parsedKey := jwksURL + parsedKeysSuffix
|
||||
if v, found := c.cache.GetLocal(parsedKey); found {
|
||||
if pj, ok := v.(*parsedJWKS); ok {
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jwks, err := c.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(jwks.Keys) == 0 {
|
||||
return nil, fmt.Errorf("JWKS response contains no keys")
|
||||
pj := buildParsedJWKS(jwks)
|
||||
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
return jwks, nil
|
||||
// buildParsedJWKS pre-parses every JWK in the set into the matching
|
||||
// crypto.PublicKey, indexed by kid. Errors on individual keys are skipped so
|
||||
// a single bad key does not block the rest of the keyset.
|
||||
func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
|
||||
out := make(map[string]crypto.PublicKey, len(jwks.Keys))
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid == "" {
|
||||
continue
|
||||
}
|
||||
// Skip keys that are not intended for signature verification.
|
||||
if k.Use != "" && k.Use != "sig" {
|
||||
continue
|
||||
}
|
||||
if len(k.KeyOps) > 0 {
|
||||
hasVerify := false
|
||||
for _, op := range k.KeyOps {
|
||||
if op == "verify" {
|
||||
hasVerify = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasVerify {
|
||||
continue
|
||||
}
|
||||
}
|
||||
var pub crypto.PublicKey
|
||||
var err error
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
pub, err = k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
pub, err = k.ToECDSAPublicKey()
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out[k.Kid] = pub
|
||||
}
|
||||
return &parsedJWKS{keys: out}
|
||||
}
|
||||
|
||||
// Cleanup is a no-op as cleanup is handled by UniversalCache
|
||||
@@ -120,11 +258,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading JWKS response: %w", err)
|
||||
}
|
||||
@@ -213,9 +351,9 @@ func (jwk *JWK) ToECDSAPublicKey() (*ecdsa.PublicKey, error) {
|
||||
// GetKey finds a key by its ID (kid) in the JWKSet.
|
||||
// Returns nil if no key with the given ID is found.
|
||||
func (jwks *JWKSet) GetKey(kid string) *JWK {
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
return &key
|
||||
for i := range jwks.Keys {
|
||||
if jwks.Keys[i].Kid == kid {
|
||||
return &jwks.Keys[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -120,7 +120,7 @@ func getReplayCacheStats() (size int, maxSize int) {
|
||||
// Parameters:
|
||||
// - ctx: Parent context for cancellation
|
||||
// - logger: Logger for debug output (can be nil)
|
||||
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
|
||||
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Define the cleanup task function
|
||||
@@ -528,6 +528,21 @@ func verifyNotBefore(notBefore float64) error {
|
||||
// - An error if the key parsing fails, the algorithm is unsupported,
|
||||
// or the signature verification fails
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
return verifySignatureWithKey(tokenString, pubKey, alg)
|
||||
}
|
||||
|
||||
// verifySignatureWithKey verifies a JWT signature using an already-parsed
|
||||
// public key, skipping the PEM-encode/decode round trip that verifySignature
|
||||
// performs. This is the hot path used by VerifyJWTSignatureAndClaims.
|
||||
func verifySignatureWithKey(tokenString string, pubKey crypto.PublicKey, alg string) error {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
@@ -537,14 +552,6 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
var hashFunc crypto.Hash
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
|
||||
@@ -134,8 +134,11 @@ func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if iss != "" && iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
|
||||
// Require a matching issuer. An empty iss must be rejected too: accepting a
|
||||
// missing issuer would let an unauthenticated attacker force-logout any
|
||||
// session whose sid is known by simply omitting iss.
|
||||
if iss == "" || iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer)
|
||||
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -316,9 +319,9 @@ func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) e
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
matchingKey = &key
|
||||
for i := range jwks.Keys {
|
||||
if jwks.Keys[i].Kid == kid {
|
||||
matchingKey = &jwks.Keys[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
+71
-27
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
@@ -124,10 +125,14 @@ func TestFrontchannelLogoutBasic(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Valid front-channel logout without issuer",
|
||||
// Front-channel logout MUST carry a matching issuer. A request
|
||||
// omitting iss is rejected so an unauthenticated attacker cannot
|
||||
// force-logout a session whose sid is known by simply leaving iss
|
||||
// out (audit rank 30).
|
||||
name: "Missing issuer is rejected",
|
||||
method: http.MethodGet,
|
||||
queryParams: map[string]string{"sid": "session456"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -406,17 +411,17 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableBackchannelLogout: true,
|
||||
backchannelLogoutPath: "/backchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
logoutURLPath: "/logout",
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableBackchannelLogout: true,
|
||||
backchannelLogoutPath: "/backchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
@@ -448,22 +453,23 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableFrontchannelLogout: true,
|
||||
frontchannelLogoutPath: "/frontchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
logoutURLPath: "/logout",
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableFrontchannelLogout: true,
|
||||
frontchannelLogoutPath: "/frontchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Request to front-channel logout path with valid sid should succeed
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil)
|
||||
// Request to front-channel logout path with valid sid + matching issuer
|
||||
// should succeed. The issuer is now required (audit rank 30), so supply it.
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session&iss=https://provider.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
@@ -639,6 +645,26 @@ func (m *mockJWKCacheForLogout) GetJWKS(ctx context.Context, jwksURL string, htt
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockJWKCacheForLogout) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *mockJWKCacheForLogout) Clear() {}
|
||||
func (m *mockJWKCacheForLogout) Cleanup() {}
|
||||
func (m *mockJWKCacheForLogout) Close() {}
|
||||
@@ -755,6 +781,22 @@ func (s *staticJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient
|
||||
return s.jwks, nil
|
||||
}
|
||||
|
||||
func (s *staticJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
for i := range s.jwks.Keys {
|
||||
k := &s.jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (s *staticJWKCache) Clear() {}
|
||||
func (s *staticJWKCache) Cleanup() {}
|
||||
func (s *staticJWKCache) Close() {}
|
||||
@@ -1395,7 +1437,9 @@ func TestFrontchannelLogoutCacheControl(t *testing.T) {
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil)
|
||||
// Issuer is now required (audit rank 30); supply a matching one so the
|
||||
// successful-logout cache headers can be asserted.
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123&iss=https://provider.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.handleFrontchannelLogout(rw, req)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
telemetry "github.com/lukaszraczylo/oss-telemetry"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@@ -23,6 +25,11 @@ const (
|
||||
ConstSessionTimeout = 86400
|
||||
)
|
||||
|
||||
// telemetryStartupOnce keeps the anonymous "plugin loaded" ping to one per
|
||||
// process. Traefik calls New once per route that uses the plugin; oss-telemetry
|
||||
// does not deduplicate client-side (the server does), so the gate stays here.
|
||||
var telemetryStartupOnce sync.Once
|
||||
|
||||
// isTestMode detects if the code is running in a test environment.
|
||||
func isTestMode() bool {
|
||||
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
|
||||
@@ -89,6 +96,13 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
// - The configured TraefikOidc handler ready to process requests.
|
||||
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
telemetryStartupOnce.Do(func() {
|
||||
// Only stamped release builds phone home; dev/local/test builds keep the
|
||||
// devPluginVersion sentinel (see version.go) and stay silent.
|
||||
if traefikoidcPluginVersion != devPluginVersion {
|
||||
telemetry.Send("traefikoidc", traefikoidcPluginVersion)
|
||||
}
|
||||
})
|
||||
return NewWithContext(ctx, config, next, name)
|
||||
}
|
||||
|
||||
@@ -99,26 +113,40 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
config = CreateConfig()
|
||||
}
|
||||
|
||||
if config.SessionEncryptionKey == "" {
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
logger := NewLogger(config.LogLevel)
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
||||
} else {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Fail closed on invalid configuration. Validate() enforces the security
|
||||
// constraints (required fields, HTTPS-only URLs, key length, excludedURLs
|
||||
// safety, rate-limit floor, audience format, ...) that were previously
|
||||
// unenforced because this constructor never called it. Crucially it rejects
|
||||
// an empty or too-short SessionEncryptionKey instead of silently
|
||||
// substituting a public hardcoded key, which would let an attacker forge
|
||||
// any session. Traefik's yaegi plugin analyzer supplies a valid key via
|
||||
// .traefik.yml testData, so it passes; only misconfigured deployments fail.
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
// Setup HTTP client
|
||||
caPool, err := config.loadCACertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load CA certificates: %w", err)
|
||||
}
|
||||
if config.InsecureSkipVerify {
|
||||
logger.Errorf("SECURITY WARNING: InsecureSkipVerify is enabled for the OIDC provider. TLS certificate verification is DISABLED. Do not use in production.")
|
||||
}
|
||||
var httpClient *http.Client
|
||||
if config.HTTPClient != nil {
|
||||
httpClient = config.HTTPClient
|
||||
} else {
|
||||
httpClient = CreateDefaultHTTPClient()
|
||||
defaultCfg := DefaultHTTPClientConfig()
|
||||
defaultCfg.RootCAs = caPool
|
||||
defaultCfg.InsecureSkipVerify = config.InsecureSkipVerify
|
||||
httpClient = CreatePooledHTTPClient(defaultCfg)
|
||||
}
|
||||
tokenCfg := TokenHTTPClientConfig()
|
||||
tokenCfg.RootCAs = caPool
|
||||
tokenCfg.InsecureSkipVerify = config.InsecureSkipVerify
|
||||
tokenHTTPClient := CreatePooledHTTPClient(tokenCfg)
|
||||
goroutineWG := &sync.WaitGroup{}
|
||||
cacheManager := GetGlobalCacheManagerWithConfig(goroutineWG, config)
|
||||
|
||||
@@ -155,6 +183,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
clientAuthMethod: func() string {
|
||||
if config.ClientAuthMethod != "" {
|
||||
return config.ClientAuthMethod
|
||||
}
|
||||
return "client_secret_post"
|
||||
}(),
|
||||
audience: func() string {
|
||||
if config.Audience != "" {
|
||||
return config.Audience
|
||||
@@ -181,6 +215,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}(),
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
extraAuthParams: config.ExtraAuthParams,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
strictAudienceValidation: config.StrictAudienceValidation,
|
||||
allowOpaqueTokens: config.AllowOpaqueTokens,
|
||||
@@ -199,9 +234,9 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: cacheManager.GetSharedTokenCache(),
|
||||
httpClient: httpClient,
|
||||
tokenHTTPClient: CreateTokenHTTPClient(),
|
||||
tokenHTTPClient: tokenHTTPClient,
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains),
|
||||
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
@@ -212,21 +247,70 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
enableBackchannelLogout: config.EnableBackchannelLogout,
|
||||
enableFrontchannelLogout: config.EnableFrontchannelLogout,
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
maxRefreshTokenAge: func() time.Duration {
|
||||
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
|
||||
if config.MaxRefreshTokenAgeSeconds > 0 {
|
||||
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
stripAuthCookies: config.StripAuthCookies,
|
||||
enableBackchannelLogout: config.EnableBackchannelLogout,
|
||||
enableFrontchannelLogout: config.EnableFrontchannelLogout,
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
|
||||
enableBearerAuth: config.EnableBearerAuth,
|
||||
stripAuthorizationHeader: config.StripAuthorizationHeader,
|
||||
bearerEmitWWWAuthenticate: config.BearerEmitWWWAuthenticate,
|
||||
bearerOverridesCookie: config.BearerOverridesCookie,
|
||||
bearerIdentifierClaim: func() string {
|
||||
if config.BearerIdentifierClaim != "" {
|
||||
return config.BearerIdentifierClaim
|
||||
}
|
||||
return "sub"
|
||||
}(),
|
||||
maxIdentifierLength: func() int {
|
||||
if config.MaxIdentifierLength > 0 {
|
||||
return config.MaxIdentifierLength
|
||||
}
|
||||
return 256
|
||||
}(),
|
||||
maxTokenAge: func() time.Duration {
|
||||
if config.MaxTokenAgeSeconds > 0 {
|
||||
return time.Duration(config.MaxTokenAgeSeconds) * time.Second
|
||||
}
|
||||
return 24 * time.Hour
|
||||
}(),
|
||||
bearerFailureThreshold: func() int {
|
||||
if config.BearerFailureThreshold > 0 {
|
||||
return config.BearerFailureThreshold
|
||||
}
|
||||
return 20
|
||||
}(),
|
||||
bearerFailureWindow: func() time.Duration {
|
||||
if config.BearerFailureWindowSeconds > 0 {
|
||||
return time.Duration(config.BearerFailureWindowSeconds) * time.Second
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
bearerFailurePenalty: func() time.Duration {
|
||||
if config.BearerFailurePenaltySeconds > 0 {
|
||||
return time.Duration(config.BearerFailurePenaltySeconds) * time.Second
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
@@ -236,15 +320,59 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
t.logger.Debugf("No custom audience specified, using clientID as audience: %s", t.clientID)
|
||||
}
|
||||
|
||||
// Bearer-auth startup validation. The bearer path is M2M-only and demands
|
||||
// a non-default audience so tokens issued for a different resource cannot
|
||||
// be replayed against this service. The BearerIdentifierClaim guard blocks
|
||||
// the `email` claim explicitly — without email_verified enforcement (out of
|
||||
// scope for M2M), trusting email is a spoofing vector for federated IdPs.
|
||||
// See spec §7.9 / §13.
|
||||
if config.EnableBearerAuth {
|
||||
if config.Audience == "" {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("EnableBearerAuth=true requires Audience to be set explicitly (cannot default to clientID — that path accepts ID tokens)")
|
||||
}
|
||||
if t.bearerIdentifierClaim == "email" {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("enableBearerAuth=true with bearerIdentifierClaim=%q is rejected: email-based identity without email_verified enforcement is a spoofing vector for federated IdPs (use \"sub\" or a custom claim; cookie-path userIdentifierClaim is unaffected)", t.bearerIdentifierClaim)
|
||||
}
|
||||
if !config.StrictAudienceValidation {
|
||||
t.logger.Infof("EnableBearerAuth=true with StrictAudienceValidation=false: recommend enabling strict audience validation for hardening")
|
||||
}
|
||||
t.bearerFailureTracker = newBearerFailureTracker(
|
||||
t.bearerFailureThreshold, t.bearerFailureWindow, t.bearerFailurePenalty,
|
||||
)
|
||||
t.logger.Infof("Bearer-token auth enabled: audience=%q identifierClaim=%q stripAuthz=%t bearerOverridesCookie=%t maxTokenAge=%s",
|
||||
config.Audience, t.bearerIdentifierClaim, t.stripAuthorizationHeader, t.bearerOverridesCookie, t.maxTokenAge)
|
||||
}
|
||||
|
||||
// Convert sessionMaxAge from seconds to duration (0 will use default 24 hours)
|
||||
sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults
|
||||
sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("failed to create session manager: %w", err)
|
||||
}
|
||||
t.sessionManager = sessionManager
|
||||
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
|
||||
|
||||
// Initialize token resilience manager with default configuration
|
||||
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
||||
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
||||
|
||||
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
|
||||
// call, preventing the thundering herd that yields invalid_grant when the IdP
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
if config.ClientAuthMethod == "private_key_jwt" {
|
||||
signer, err := buildClientAssertionSignerFromConfig(config)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("failed to build client assertion signer: %w", err)
|
||||
}
|
||||
t.clientAssertion = signer
|
||||
}
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -292,17 +420,22 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
startReplayCacheCleanup(pluginCtx, logger)
|
||||
|
||||
// Start memory monitoring for leak detection and performance insights
|
||||
// Start memory monitoring for leak detection and performance insights.
|
||||
// The interval is clamped to MinMemoryMonitorInterval (30s) inside
|
||||
// StartMonitoring; tests that need deterministic sampling should call
|
||||
// MemoryMonitor.Refresh() directly instead of waiting on a fast ticker.
|
||||
memoryMonitor := GetGlobalMemoryMonitor()
|
||||
monitorInterval := 60 * time.Second
|
||||
if isTestMode() {
|
||||
monitorInterval = 100 * time.Millisecond // Fast interval for tests
|
||||
}
|
||||
memoryMonitor.StartMonitoring(pluginCtx, monitorInterval)
|
||||
memoryMonitor.StartMonitoring(pluginCtx, DefaultMemoryMonitorInterval)
|
||||
logger.Debug("Started global memory monitoring")
|
||||
|
||||
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
|
||||
|
||||
// Log callback URL configuration to help diagnose redirect loop issues.
|
||||
// If callbackURL is a full URL instead of a path, the callback matching
|
||||
// in ServeHTTP will silently fail because req.URL.Path is compared directly.
|
||||
logger.Debugf("TraefikOidc.New: callbackURL (redirURLPath) configured as: %q", t.redirURLPath)
|
||||
logger.Debugf("TraefikOidc.New: logoutURLPath configured as: %q", t.logoutURLPath)
|
||||
|
||||
t.providerURL = config.ProviderURL
|
||||
|
||||
// Use singleton resource manager for metadata initialization
|
||||
@@ -310,6 +443,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
// Add reference for this instance
|
||||
rm.AddReference(name)
|
||||
registerLiveInstance()
|
||||
|
||||
// Initialize metadata in a goroutine with proper tracking
|
||||
if t.goroutineWG != nil {
|
||||
@@ -387,13 +521,58 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
// Parameters:
|
||||
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
// SSRF defense (audit ranks 3 & 4): a discovery document is attacker-
|
||||
// influenced when the provider or its TLS is compromised. Reject any
|
||||
// discovered endpoint pointed at a blocked address before the plugin issues
|
||||
// outbound requests to it, so it can never be used to reach the cloud
|
||||
// metadata service or an internal host.
|
||||
allowLoopback := false
|
||||
if pu, err := url.Parse(t.providerURL); err == nil {
|
||||
allowLoopback = isLoopbackHost(pu.Hostname())
|
||||
}
|
||||
sanitize := func(name, raw string) string {
|
||||
if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil {
|
||||
t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err)
|
||||
return ""
|
||||
}
|
||||
return raw
|
||||
}
|
||||
metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL)
|
||||
metadata.AuthURL = sanitize("authorization", metadata.AuthURL)
|
||||
metadata.TokenURL = sanitize("token", metadata.TokenURL)
|
||||
metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL)
|
||||
metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL)
|
||||
metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL)
|
||||
metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL)
|
||||
// The introspection request authenticates with the client secret via HTTP
|
||||
// Basic, so the endpoint must live on the same host as the operator-
|
||||
// configured provider; otherwise a poisoned discovery document could
|
||||
// exfiltrate the client secret to an attacker-controlled host.
|
||||
if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) {
|
||||
t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL)
|
||||
metadata.IntrospectionURL = ""
|
||||
}
|
||||
|
||||
// Pin the discovered issuer to the operator-configured provider host. The
|
||||
// issuer is the trust anchor for JWT issuer validation, so a poisoned
|
||||
// discovery document advertising an attacker-chosen issuer must never be
|
||||
// stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the
|
||||
// issuer on the same host as the configured providerURL. On mismatch, leave
|
||||
// issuerURL empty/unchanged so downstream issuer validation fails closed
|
||||
// rather than trusting the attacker-chosen value.
|
||||
discoveredIssuer := metadata.Issuer
|
||||
if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) {
|
||||
t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer)
|
||||
discoveredIssuer = ""
|
||||
}
|
||||
|
||||
t.metadataMu.Lock()
|
||||
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.issuerURL = discoveredIssuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
|
||||
@@ -403,6 +582,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
introspectionURL := t.introspectionURL
|
||||
registrationURL := t.registrationURL
|
||||
|
||||
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
|
||||
// this directly instead of acquiring metadataMu.RLock per request.
|
||||
t.metadataSnapshot.Store(&MetadataSnapshot{
|
||||
IssuerURL: discoveredIssuer,
|
||||
JWKSURL: metadata.JWKSURL,
|
||||
TokenURL: metadata.TokenURL,
|
||||
AuthURL: metadata.AuthURL,
|
||||
RevocationURL: metadata.RevokeURL,
|
||||
EndSessionURL: metadata.EndSessionURL,
|
||||
IntrospectionURL: metadata.IntrospectionURL,
|
||||
RegistrationURL: metadata.RegistrationURL,
|
||||
})
|
||||
|
||||
t.metadataMu.Unlock()
|
||||
|
||||
// Log introspection endpoint availability for opaque token support
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config Marshalling Tests
|
||||
// Config Marshaling Tests
|
||||
|
||||
func TestConfig_MarshalJSON(t *testing.T) {
|
||||
config := &Config{
|
||||
|
||||
@@ -194,6 +194,7 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(ctx, nil, config, "test")
|
||||
if err != nil {
|
||||
@@ -322,6 +323,7 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(ctx, nil, config, "test")
|
||||
if err != nil {
|
||||
|
||||
+70
-68
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -25,38 +26,47 @@ func TestInitializeMetadata(t *testing.T) {
|
||||
name: "successful metadata initialization",
|
||||
providerURL: "",
|
||||
setupMock: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer must share the host with providerURL (the httptest
|
||||
// server), otherwise the discovery doc is rejected as poisoned
|
||||
// (audit ranks 21/22). Real providers keep issuer + endpoints on
|
||||
// the same host, so derive them all from the server URL.
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://provider.example.com",
|
||||
AuthURL: "https://provider.example.com/auth",
|
||||
TokenURL: "https://provider.example.com/token",
|
||||
JWKSURL: "https://provider.example.com/jwks",
|
||||
RevokeURL: "https://provider.example.com/revoke",
|
||||
EndSessionURL: "https://provider.example.com/logout",
|
||||
Issuer: srv.URL,
|
||||
AuthURL: srv.URL + "/auth",
|
||||
TokenURL: srv.URL + "/token",
|
||||
JWKSURL: srv.URL + "/jwks",
|
||||
RevokeURL: srv.URL + "/revoke",
|
||||
EndSessionURL: srv.URL + "/logout",
|
||||
})
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
},
|
||||
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
|
||||
if oidc.authURL != "https://provider.example.com/auth" {
|
||||
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
|
||||
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
|
||||
}
|
||||
if oidc.tokenURL != "https://provider.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
|
||||
}
|
||||
if oidc.jwksURL != "https://provider.example.com/jwks" {
|
||||
if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") {
|
||||
t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL)
|
||||
}
|
||||
if oidc.revocationURL != "https://provider.example.com/revoke" {
|
||||
if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") {
|
||||
t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL)
|
||||
}
|
||||
if oidc.endSessionURL != "https://provider.example.com/logout" {
|
||||
if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") {
|
||||
t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL)
|
||||
}
|
||||
if oidc.issuerURL == "" {
|
||||
t.Errorf("expected issuerURL to be pinned to provider host, got empty")
|
||||
}
|
||||
},
|
||||
wantPanic: false,
|
||||
},
|
||||
@@ -115,24 +125,27 @@ func TestInitializeMetadata(t *testing.T) {
|
||||
name: "partial metadata response",
|
||||
providerURL: "",
|
||||
setupMock: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Only return some fields
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"issuer": "https://partial.example.com",
|
||||
"authorization_endpoint": "https://partial.example.com/auth",
|
||||
"token_endpoint": "https://partial.example.com/token",
|
||||
"issuer": srv.URL,
|
||||
"authorization_endpoint": srv.URL + "/auth",
|
||||
"token_endpoint": srv.URL + "/token",
|
||||
// Missing jwks_uri, revocation_endpoint, end_session_endpoint
|
||||
})
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
},
|
||||
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
|
||||
if oidc.authURL != "https://partial.example.com/auth" {
|
||||
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
|
||||
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
|
||||
}
|
||||
if oidc.tokenURL != "https://partial.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
|
||||
}
|
||||
// JWKS URL and others may be empty
|
||||
@@ -197,20 +210,22 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
|
||||
requestCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requestCount++
|
||||
mu.Unlock()
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://concurrent.example.com",
|
||||
AuthURL: "https://concurrent.example.com/auth",
|
||||
TokenURL: "https://concurrent.example.com/token",
|
||||
JWKSURL: "https://concurrent.example.com/jwks",
|
||||
RevokeURL: "https://concurrent.example.com/revoke",
|
||||
EndSessionURL: "https://concurrent.example.com/logout",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
RevokeURL: server.URL + "/revoke",
|
||||
EndSessionURL: server.URL + "/logout",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -249,7 +264,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
|
||||
// Verify initialization
|
||||
if oidc.tokenURL != "https://concurrent.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set")
|
||||
}
|
||||
}()
|
||||
@@ -341,17 +356,19 @@ func TestProviderDetection(t *testing.T) {
|
||||
// TestInitializationWaiting tests waiting for initialization to complete
|
||||
func TestInitializationWaiting(t *testing.T) {
|
||||
t.Run("wait for initialization completion", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Delay response to simulate slow initialization
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://slow.example.com",
|
||||
AuthURL: "https://slow.example.com/auth",
|
||||
TokenURL: "https://slow.example.com/token",
|
||||
JWKSURL: "https://slow.example.com/jwks",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -388,7 +405,7 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
select {
|
||||
case <-oidc.initComplete:
|
||||
// Success
|
||||
if oidc.tokenURL != "https://slow.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Error("expected tokenURL to be set after initialization")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
@@ -397,17 +414,19 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("multiple waiters for initialization", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Delay to ensure multiple waiters
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://multi.example.com",
|
||||
AuthURL: "https://multi.example.com/auth",
|
||||
TokenURL: "https://multi.example.com/token",
|
||||
JWKSURL: "https://multi.example.com/jwks",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -452,7 +471,7 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
select {
|
||||
case <-oidc.initComplete:
|
||||
// All waiters should see the same initialized state
|
||||
if oidc.tokenURL != "https://multi.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("waiter %d: expected tokenURL to be set", id)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
@@ -484,9 +503,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -508,19 +526,13 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate first request processing
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
// Simulate first request processing — single-firing via CAS.
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
// This would normally be called asynchronously
|
||||
go func() {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
// initComplete is closed internally by initializeMetadata
|
||||
}()
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait for initialization
|
||||
@@ -556,9 +568,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -580,31 +591,22 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate multiple concurrent "first" requests
|
||||
// Simulate multiple concurrent "first" requests — only one CAS winner
|
||||
// fires the bootstrap path.
|
||||
const numRequests = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
initStarted := 0
|
||||
var initMu sync.Mutex
|
||||
var initStarted int32
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
initMu.Lock()
|
||||
initStarted++
|
||||
initMu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
atomic.AddInt32(&initStarted, 1)
|
||||
// Only one should actually start initialization
|
||||
oidc.initializeMetadata(server.URL)
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -612,8 +614,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// Verify only one initialization was started
|
||||
if initStarted != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", initStarted)
|
||||
if atomic.LoadInt32(&initStarted) != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted))
|
||||
}
|
||||
|
||||
// The metadata endpoint might be called once or not at all depending on timing
|
||||
|
||||
+549
-56
@@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com", // Required for initialization check
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_EventStream tests the event-stream bypass functionality
|
||||
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
|
||||
// handshake must skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but it must STILL require an authenticated session, otherwise
|
||||
// any caller could reach the backend by setting Accept: text/event-stream.
|
||||
func TestServeHTTP_EventStream(t *testing.T) {
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
|
||||
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
|
||||
}
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
// Build an authenticated session and inject its cookies onto req.
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated SSE request to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
|
||||
// handshake bypasses the OIDC redirect (clients can't follow it) but the
|
||||
// session must already be authenticated, otherwise the backend is exposed
|
||||
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
|
||||
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected event-stream request to bypass OIDC")
|
||||
}
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
// Mixed-case + multi-token Connection header to exercise parsing.
|
||||
req.Header.Set("Connection", "keep-alive, Upgrade")
|
||||
req.Header.Set("Upgrade", "WebSocket")
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("ws-user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "ws-user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
|
||||
// Sanity: requests without Upgrade headers must NOT hit the WS
|
||||
// bypass branch (otherwise the new code path could short-circuit
|
||||
// normal authentication).
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("backend must not be called for unauthenticated plain HTTP")
|
||||
}))
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code == http.StatusOK {
|
||||
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
|
||||
@@ -120,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close this to simulate timeout
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
@@ -155,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -185,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -215,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "successful authorization with email",
|
||||
setupSession: func() *MockSessionData {
|
||||
session := &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "no email triggers reauth",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "",
|
||||
userIdentifier: "",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "roles and groups authorization",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "unauthorized role/group returns 403",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "template headers processing",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "OPTIONS request with CORS",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
manager: &SessionManager{logger: NewLogger("debug")},
|
||||
}
|
||||
// Copy values from mock to concrete session
|
||||
concreteSession.SetEmail(session.email)
|
||||
concreteSession.SetUserIdentifier(session.userIdentifier)
|
||||
concreteSession.SetIDToken(session.idToken)
|
||||
concreteSession.SetAccessToken(session.accessToken)
|
||||
concreteSession.SetRefreshToken(session.refreshToken)
|
||||
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
|
||||
// MockSessionData is a test implementation of SessionData interface
|
||||
type MockSessionData struct {
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
userIdentifier string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
|
||||
func (m *MockSessionData) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSessionData) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
|
||||
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
|
||||
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
|
||||
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
@@ -588,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
@@ -665,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: true, // Enable minimal headers
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -710,3 +862,344 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies tests the stripAuthCookies configuration option.
|
||||
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
|
||||
func TestStripAuthCookies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stripAuthCookies bool
|
||||
expectOIDCCookies bool
|
||||
expectAppCookies bool
|
||||
}{
|
||||
{
|
||||
name: "stripAuthCookies=false (default) forwards all cookies",
|
||||
stripAuthCookies: false,
|
||||
expectOIDCCookies: true,
|
||||
expectAppCookies: true,
|
||||
},
|
||||
{
|
||||
name: "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
|
||||
stripAuthCookies: true,
|
||||
expectOIDCCookies: false,
|
||||
expectAppCookies: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
cookiePrefix := sessionManager.GetCookiePrefix()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: tt.stripAuthCookies,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Get a valid session first (before adding fake cookies)
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Now add OIDC session cookies (simulating what the browser would send)
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
|
||||
|
||||
// Add non-OIDC application cookies (these must always pass through)
|
||||
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
|
||||
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Check for OIDC cookies in captured cookies
|
||||
hasOIDCCookie := false
|
||||
hasAppSession := false
|
||||
hasTheme := false
|
||||
for _, c := range capturedCookies {
|
||||
if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
|
||||
hasOIDCCookie = true
|
||||
}
|
||||
if c.Name == "my_app_session" {
|
||||
hasAppSession = true
|
||||
}
|
||||
if c.Name == "theme" {
|
||||
hasTheme = true
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectOIDCCookies && !hasOIDCCookie {
|
||||
t.Error("expected OIDC cookies to be forwarded to backend")
|
||||
}
|
||||
if !tt.expectOIDCCookies && hasOIDCCookie {
|
||||
t.Error("expected OIDC cookies to be stripped before forwarding to backend")
|
||||
}
|
||||
|
||||
if tt.expectAppCookies && !hasAppSession {
|
||||
t.Error("expected my_app_session cookie to be forwarded to backend")
|
||||
}
|
||||
if tt.expectAppCookies && !hasTheme {
|
||||
t.Error("expected theme cookie to be forwarded to backend")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_NoCookies verifies stripping works when the request has no cookies.
|
||||
func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
if len(capturedCookies) != 0 {
|
||||
t.Errorf("expected no cookies, got %d", len(capturedCookies))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_OnlyOIDCCookies verifies that when all cookies are OIDC cookies,
|
||||
// the Cookie header is empty after stripping.
|
||||
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
var capturedCookieHeader string
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookieHeader = r.Header.Get("Cookie")
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
cookiePrefix := sessionManager.GetCookiePrefix()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only OIDC cookies
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
|
||||
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
if len(capturedCookies) != 0 {
|
||||
t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
|
||||
}
|
||||
if capturedCookieHeader != "" {
|
||||
t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_OnlyAppCookies verifies that non-OIDC cookies pass through
|
||||
// untouched when stripping is enabled.
|
||||
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only non-OIDC cookies
|
||||
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "abc123"})
|
||||
req.AddCookie(&http.Cookie{Name: "lang", Value: "en"})
|
||||
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
if len(capturedCookies) != 3 {
|
||||
t.Errorf("expected 3 cookies, got %d", len(capturedCookies))
|
||||
}
|
||||
|
||||
cookieNames := make(map[string]bool)
|
||||
for _, c := range capturedCookies {
|
||||
cookieNames[c.Name] = true
|
||||
}
|
||||
for _, expected := range []string{"my_app_session", "lang", "theme"} {
|
||||
if !cookieNames[expected] {
|
||||
t.Errorf("expected cookie %q to be forwarded", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestStripAuthCookies_CustomPrefix verifies stripping works with a custom cookie prefix.
|
||||
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
var capturedCookies []*http.Cookie
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCookies = r.Cookies()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create session manager with custom prefix
|
||||
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
customPrefix := sm.GetCookiePrefix()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add cookies with the custom prefix (should be stripped)
|
||||
req.AddCookie(&http.Cookie{Name: customPrefix + "m", Value: "session-data"})
|
||||
req.AddCookie(&http.Cookie{Name: customPrefix + "s_0", Value: "chunk0"})
|
||||
|
||||
// Add default-prefix cookie (should NOT be stripped — different prefix)
|
||||
req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
|
||||
|
||||
// Add app cookie (should NOT be stripped)
|
||||
req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
cookieNames := make(map[string]bool)
|
||||
for _, c := range capturedCookies {
|
||||
cookieNames[c.Name] = true
|
||||
}
|
||||
|
||||
// Custom prefix cookies should be stripped
|
||||
if cookieNames[customPrefix+"m"] {
|
||||
t.Errorf("expected cookie %q to be stripped", customPrefix+"m")
|
||||
}
|
||||
if cookieNames[customPrefix+"s_0"] {
|
||||
t.Errorf("expected cookie %q to be stripped", customPrefix+"s_0")
|
||||
}
|
||||
|
||||
// Default prefix cookie should pass through (different prefix)
|
||||
if !cookieNames["_oidc_raczylo_m"] {
|
||||
t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
|
||||
}
|
||||
|
||||
// App cookie should pass through
|
||||
if !cookieNames["my_app"] {
|
||||
t.Error("expected my_app cookie to pass through")
|
||||
}
|
||||
}
|
||||
|
||||
+113
-63
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -208,6 +209,32 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
if m.JWKS == nil {
|
||||
return nil, fmt.Errorf("JWKS is nil")
|
||||
}
|
||||
for i := range m.JWKS.Keys {
|
||||
k := &m.JWKS.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) Cleanup() {
|
||||
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
|
||||
// Real cleanup is for expired entries, not resetting all data
|
||||
@@ -554,7 +581,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -577,7 +604,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -634,7 +661,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -652,7 +679,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -680,7 +707,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -715,7 +742,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
@@ -746,7 +773,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetIDToken(validToken) // Ensure ID token is also set
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
@@ -766,7 +793,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -788,7 +815,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -1848,14 +1875,14 @@ func TestHandleLogout(t *testing.T) {
|
||||
},
|
||||
endSessionURL: "",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
expectedURL: "/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Logout with empty session",
|
||||
setupSession: func(session *SessionData) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
expectedURL: "/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
@@ -2153,7 +2180,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAccessToken(expiredToken)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
@@ -2322,19 +2349,22 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
t.Skip("Skipping test in short mode")
|
||||
}
|
||||
|
||||
// Create mock provider metadata server
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create mock provider metadata server. Issuer + endpoints must share the
|
||||
// host with ProviderURL (the httptest server), otherwise the discovery doc
|
||||
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
|
||||
var mockServer *httptest.Server
|
||||
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
RevokeURL: "https://test-issuer.com/revoke",
|
||||
EndSessionURL: "https://test-issuer.com/end-session",
|
||||
Issuer: mockServer.URL,
|
||||
AuthURL: mockServer.URL + "/auth",
|
||||
TokenURL: mockServer.URL + "/token",
|
||||
JWKSURL: mockServer.URL + "/jwks",
|
||||
RevokeURL: mockServer.URL + "/revoke",
|
||||
EndSessionURL: mockServer.URL + "/end-session",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
@@ -2347,6 +2377,7 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create multiple middleware instances
|
||||
@@ -2387,18 +2418,20 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
t.Fatalf("Middleware instance %d failed to initialize", i)
|
||||
}
|
||||
|
||||
// Verify each instance has its own unique configuration
|
||||
if m.issuerURL != "https://test-issuer.com" {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
|
||||
// Verify each instance has its own unique configuration. Issuer is now
|
||||
// pinned to the provider host (audit ranks 21/22), so it equals the
|
||||
// mock server URL rather than a fixed literal.
|
||||
if m.issuerURL != mockServer.URL {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
|
||||
}
|
||||
if m.authURL != "https://test-issuer.com/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
|
||||
if m.authURL != mockServer.URL+"/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
|
||||
}
|
||||
if m.tokenURL != "https://test-issuer.com/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
|
||||
if m.tokenURL != mockServer.URL+"/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
|
||||
}
|
||||
if m.jwksURL != "https://test-issuer.com/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
|
||||
if m.jwksURL != mockServer.URL+"/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
|
||||
}
|
||||
if m.redirURLPath != routes[i]+"/callback" {
|
||||
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
|
||||
@@ -2412,15 +2445,16 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
|
||||
m.ServeHTTP(rr, req)
|
||||
|
||||
// Should redirect to auth URL since not authenticated
|
||||
// Should redirect (302) to the auth flow since not authenticated. The
|
||||
// absolute auth URL is not asserted here: with issuer pinning (audit
|
||||
// ranks 21/22) the discovery host equals the httptest server host,
|
||||
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
|
||||
// to emit a loopback authorization URL in this test environment. The
|
||||
// per-instance auth/token/jwks/issuer URLs were already verified above;
|
||||
// here we only confirm each instance independently triggers a redirect.
|
||||
if rr.Code != http.StatusFound {
|
||||
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
|
||||
}
|
||||
|
||||
location := rr.Header().Get("Location")
|
||||
if !strings.Contains(location, "https://test-issuer.com/auth") {
|
||||
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2433,33 +2467,43 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create two mock provider metadata servers simulating different Keycloak realms
|
||||
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer + endpoints must share the host with each realm's ProviderURL
|
||||
// (the httptest server), otherwise the discovery doc is rejected as
|
||||
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
|
||||
// path so the per-realm isolation assertions below still hold, but base
|
||||
// the host on the server URL — which is exactly what a same-host Keycloak
|
||||
// deployment looks like.
|
||||
var realm1Server *httptest.Server
|
||||
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
base := realm1Server.URL + "/realms/realm1"
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm1",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
|
||||
Issuer: base,
|
||||
AuthURL: base + "/protocol/openid-connect/auth",
|
||||
TokenURL: base + "/protocol/openid-connect/token",
|
||||
JWKSURL: base + "/protocol/openid-connect/certs",
|
||||
EndSessionURL: base + "/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer realm1Server.Close()
|
||||
|
||||
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var realm2Server *httptest.Server
|
||||
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
base := realm2Server.URL + "/realms/realm2"
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm2",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
|
||||
Issuer: base,
|
||||
AuthURL: base + "/protocol/openid-connect/auth",
|
||||
TokenURL: base + "/protocol/openid-connect/token",
|
||||
JWKSURL: base + "/protocol/openid-connect/certs",
|
||||
EndSessionURL: base + "/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
@@ -2473,6 +2517,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
CallbackURL: "/realm1/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm1_",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Config for realm2
|
||||
@@ -2483,6 +2528,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
CallbackURL: "/realm2/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm2_",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware instances for both realms
|
||||
@@ -2581,8 +2627,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable := false
|
||||
var mu sync.Mutex
|
||||
|
||||
// Create mock provider that initially fails, then becomes available
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create mock provider that initially fails, then becomes available.
|
||||
// Issuer + endpoints must share the host with ProviderURL (audit ranks
|
||||
// 21/22), so derive them from the server URL.
|
||||
var mockServer *httptest.Server
|
||||
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
available := providerAvailable
|
||||
mu.Unlock()
|
||||
@@ -2594,11 +2643,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
EndSessionURL: "https://test-issuer.com/logout",
|
||||
Issuer: mockServer.URL,
|
||||
AuthURL: mockServer.URL + "/auth",
|
||||
TokenURL: mockServer.URL + "/token",
|
||||
JWKSURL: mockServer.URL + "/jwks",
|
||||
EndSessionURL: mockServer.URL + "/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
return
|
||||
@@ -2613,6 +2662,7 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware while provider is unavailable
|
||||
@@ -2659,10 +2709,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable = true
|
||||
mu.Unlock()
|
||||
|
||||
// Reset the retry timer to allow immediate retry
|
||||
m.metadataRetryMutex.Lock()
|
||||
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
|
||||
m.metadataRetryMutex.Unlock()
|
||||
// Reset the retry timer to allow immediate retry. The field is atomic
|
||||
// now, so no lock is needed.
|
||||
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
|
||||
|
||||
// Second request should trigger recovery attempt
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
@@ -2730,7 +2779,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2756,7 +2805,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2783,7 +2832,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
@@ -2803,7 +2852,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2825,7 +2874,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
@@ -4526,6 +4575,7 @@ func TestNewWithScopeAppending(t *testing.T) {
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
Scopes: tc.configScopes,
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware instance
|
||||
|
||||
+24
-15
@@ -9,13 +9,18 @@ import (
|
||||
// LazyBackgroundTask wraps BackgroundTask to provide delayed initialization.
|
||||
// This prevents memory leaks from unnecessary background tasks by starting
|
||||
// them only when actually needed, reducing resource usage in idle scenarios.
|
||||
//
|
||||
// Lifecycle is one-shot: once Stop has been called the task cannot be
|
||||
// restarted. The underlying BackgroundTask uses sync.Once for Start and
|
||||
// refuses to re-run after Stop, so restart is not supported by design.
|
||||
type LazyBackgroundTask struct {
|
||||
// BackgroundTask is the underlying task implementation
|
||||
*BackgroundTask
|
||||
// started tracks whether the task has been activated
|
||||
// mu guards the started flag against concurrent StartIfNeeded / Stop calls.
|
||||
mu sync.Mutex
|
||||
// started tracks whether the task has been activated.
|
||||
// Only mutated while holding mu.
|
||||
started bool
|
||||
// startOnce ensures single initialization
|
||||
startOnce sync.Once
|
||||
}
|
||||
|
||||
// NewLazyBackgroundTask creates a background task that doesn't start immediately.
|
||||
@@ -29,24 +34,28 @@ func NewLazyBackgroundTask(name string, interval time.Duration, taskFunc func(),
|
||||
}
|
||||
|
||||
// StartIfNeeded starts the background task only if it hasn't been started yet.
|
||||
// Uses sync.Once to ensure thread-safe single initialization.
|
||||
// Safe to call concurrently. After Stop has been called this is a no-op;
|
||||
// the task is not restartable.
|
||||
func (lt *LazyBackgroundTask) StartIfNeeded() {
|
||||
lt.startOnce.Do(func() {
|
||||
if !lt.started {
|
||||
lt.BackgroundTask.Start()
|
||||
lt.started = true
|
||||
}
|
||||
})
|
||||
lt.mu.Lock()
|
||||
defer lt.mu.Unlock()
|
||||
if lt.started {
|
||||
return
|
||||
}
|
||||
lt.BackgroundTask.Start()
|
||||
lt.started = true
|
||||
}
|
||||
|
||||
// Stop stops the background task if it was started.
|
||||
// Resets the start state to allow potential future re-initialization.
|
||||
// Once stopped, the task cannot be restarted (see type doc).
|
||||
func (lt *LazyBackgroundTask) Stop() {
|
||||
if lt.started {
|
||||
lt.BackgroundTask.Stop()
|
||||
lt.started = false
|
||||
lt.startOnce = sync.Once{}
|
||||
lt.mu.Lock()
|
||||
defer lt.mu.Unlock()
|
||||
if !lt.started {
|
||||
return
|
||||
}
|
||||
lt.BackgroundTask.Stop()
|
||||
lt.started = false
|
||||
}
|
||||
|
||||
// NewLazyCacheWithLogger creates a cache that doesn't start cleanup until first use.
|
||||
|
||||
@@ -1652,6 +1652,7 @@ func TestGoroutineLeaks(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client"
|
||||
config.ClientSecret = "test-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(context.Background(), nil, config, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
+142
-12
@@ -58,13 +58,21 @@ func (mpl MemoryPressureLevel) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryMonitor provides comprehensive memory monitoring and alerting
|
||||
// MemoryMonitor provides comprehensive memory monitoring and alerting.
|
||||
//
|
||||
// Memory sampling is expensive: runtime.ReadMemStats is a stop-the-world
|
||||
// operation. To keep latency predictable the monitor caches the most recent
|
||||
// sample and only refreshes it when the background ticker fires, when TriggerGC
|
||||
// is invoked, or when a caller explicitly calls Refresh(). GetCurrentStats is a
|
||||
// cheap read of that cached sample.
|
||||
type MemoryMonitor struct {
|
||||
lastGCTime time.Time
|
||||
startTime time.Time
|
||||
lastStats *MemoryStats
|
||||
cachedMemStats runtime.MemStats
|
||||
logger *Logger
|
||||
alertThresholds MemoryAlertThresholds
|
||||
config MemoryMonitorConfig
|
||||
baselineGoroutines int
|
||||
baselineHeap uint64
|
||||
heapGrowthRate float64
|
||||
@@ -84,6 +92,30 @@ type MemoryAlertThresholds struct {
|
||||
GCFrequency float64 // Alert when GC frequency exceeds this per minute
|
||||
}
|
||||
|
||||
// MemoryMonitorConfig configures the memory monitor's scheduling behavior.
|
||||
// Thresholds are kept separate in MemoryAlertThresholds.
|
||||
type MemoryMonitorConfig struct {
|
||||
// Interval between background samples. Must be >= MinMemoryMonitorInterval
|
||||
// (30s). Values below the minimum are clamped when monitoring starts.
|
||||
Interval time.Duration
|
||||
}
|
||||
|
||||
// Default and minimum interval values. The minimum exists because
|
||||
// runtime.ReadMemStats is stop-the-world and hammering it on a hot loop causes
|
||||
// noticeable latency spikes, especially under yaegi.
|
||||
const (
|
||||
DefaultMemoryMonitorInterval = 60 * time.Second
|
||||
MinMemoryMonitorInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
// DefaultMemoryMonitorConfig returns a config with sensible production
|
||||
// defaults.
|
||||
func DefaultMemoryMonitorConfig() MemoryMonitorConfig {
|
||||
return MemoryMonitorConfig{
|
||||
Interval: DefaultMemoryMonitorInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMemoryAlertThresholds returns sensible default alert thresholds
|
||||
func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
|
||||
return MemoryAlertThresholds{
|
||||
@@ -95,35 +127,82 @@ func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
|
||||
}
|
||||
}
|
||||
|
||||
// NewMemoryMonitor creates a new memory monitor
|
||||
// NewMemoryMonitor creates a new memory monitor using default scheduling
|
||||
// configuration. See NewMemoryMonitorWithConfig for full control.
|
||||
func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryMonitor {
|
||||
return NewMemoryMonitorWithConfig(logger, thresholds, DefaultMemoryMonitorConfig())
|
||||
}
|
||||
|
||||
// NewMemoryMonitorWithConfig creates a new memory monitor with an explicit
|
||||
// scheduling config.
|
||||
//
|
||||
// NOTE: the constructor performs a single runtime.ReadMemStats call to capture
|
||||
// baseline heap / goroutine / GC counters used for leak and growth detection.
|
||||
// This is a one-time stop-the-world cost at startup; all subsequent samples
|
||||
// only happen on the monitoring ticker or on explicit Refresh() calls.
|
||||
func NewMemoryMonitorWithConfig(logger *Logger, thresholds MemoryAlertThresholds, config MemoryMonitorConfig) *MemoryMonitor {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
if config.Interval <= 0 {
|
||||
config.Interval = DefaultMemoryMonitorInterval
|
||||
}
|
||||
|
||||
// One-time initial sample to seed baselines used for growth / leak
|
||||
// detection. All subsequent sampling is gated by the monitoring ticker or
|
||||
// explicit Refresh() calls.
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
return &MemoryMonitor{
|
||||
mm := &MemoryMonitor{
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
alertThresholds: thresholds,
|
||||
config: config,
|
||||
baselineHeap: memStats.HeapAlloc,
|
||||
baselineGoroutines: runtime.NumGoroutine(),
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
lastGCCount: memStats.NumGC,
|
||||
}
|
||||
mm.cachedMemStats = memStats
|
||||
return mm
|
||||
}
|
||||
|
||||
// GetCurrentStats collects current memory statistics
|
||||
// GetCurrentStats returns the most recently sampled memory statistics.
|
||||
//
|
||||
// This is a cheap cached read: it does NOT call runtime.ReadMemStats. Samples
|
||||
// are refreshed only by the monitoring ticker or by an explicit call to
|
||||
// Refresh(). If no sample has been produced yet, stats derived from the
|
||||
// constructor-time raw sample are returned (with no additional STW cost).
|
||||
func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
mm.mu.RLock()
|
||||
stats := mm.lastStats
|
||||
mm.mu.RUnlock()
|
||||
if stats != nil {
|
||||
return stats
|
||||
}
|
||||
return mm.buildStatsFromCache()
|
||||
}
|
||||
|
||||
// Refresh synchronously samples current memory statistics via
|
||||
// runtime.ReadMemStats and updates the cached value. This is the only path
|
||||
// (other than the monitoring ticker and TriggerGC) that pays the stop-the-world
|
||||
// cost. Use it in tests or in callers that explicitly need a fresh sample.
|
||||
func (mm *MemoryMonitor) Refresh() *MemoryStats {
|
||||
return mm.sample()
|
||||
}
|
||||
|
||||
// sample performs a stop-the-world ReadMemStats, updates the cached raw stats,
|
||||
// computes a derived MemoryStats snapshot, and stores it as lastStats.
|
||||
func (mm *MemoryMonitor) sample() *MemoryStats {
|
||||
var memStats runtime.MemStats
|
||||
runtime.ReadMemStats(&memStats)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Calculate GC frequency
|
||||
// Calculate GC frequency relative to the previous snapshot.
|
||||
gcFrequency := 0.0
|
||||
mm.mu.RLock()
|
||||
lastStats := mm.lastStats
|
||||
@@ -168,6 +247,7 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
mm.updateHeapGrowthTracking(stats)
|
||||
|
||||
mm.mu.Lock()
|
||||
mm.cachedMemStats = memStats
|
||||
mm.lastStats = stats
|
||||
mm.lastGCCount = memStats.NumGC
|
||||
mm.mu.Unlock()
|
||||
@@ -175,6 +255,35 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
// buildStatsFromCache constructs a MemoryStats snapshot from the cached raw
|
||||
// runtime.MemStats without issuing a new ReadMemStats call. Used as a fallback
|
||||
// when GetCurrentStats is called before the first sample() has completed.
|
||||
func (mm *MemoryMonitor) buildStatsFromCache() *MemoryStats {
|
||||
mm.mu.RLock()
|
||||
memStats := mm.cachedMemStats
|
||||
mm.mu.RUnlock()
|
||||
|
||||
stats := &MemoryStats{
|
||||
HeapAllocBytes: memStats.HeapAlloc,
|
||||
HeapSysBytes: memStats.HeapSys,
|
||||
HeapIdleBytes: memStats.HeapIdle,
|
||||
HeapInuseBytes: memStats.HeapInuse,
|
||||
HeapReleasedBytes: memStats.HeapReleased,
|
||||
HeapObjects: memStats.HeapObjects,
|
||||
StackInuseBytes: memStats.StackInuse,
|
||||
StackSysBytes: memStats.StackSys,
|
||||
GCSysBytes: memStats.GCSys,
|
||||
NumGoroutines: runtime.NumGoroutine(),
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
GCFrequency: 0.0,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
mm.collectApplicationStats(stats)
|
||||
stats.MemoryPressure = mm.calculateMemoryPressure(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// collectApplicationStats gathers application-specific memory stats
|
||||
func (mm *MemoryMonitor) collectApplicationStats(stats *MemoryStats) {
|
||||
// Get session count from ChunkManager if available
|
||||
@@ -229,7 +338,7 @@ func (mm *MemoryMonitor) updateGoroutineTracking(stats *MemoryStats) {
|
||||
}
|
||||
|
||||
// Check for potential goroutine leak
|
||||
if stats.NumGoroutines > mm.baselineGoroutines+int(mm.alertThresholds.GoroutineCount) {
|
||||
if stats.NumGoroutines > mm.baselineGoroutines+mm.alertThresholds.GoroutineCount {
|
||||
mm.mu.Lock()
|
||||
wasAlert := mm.goroutineLeakAlert
|
||||
if !wasAlert {
|
||||
@@ -302,7 +411,16 @@ var (
|
||||
globalMonitoringMutex sync.Mutex
|
||||
)
|
||||
|
||||
// StartMonitoring starts continuous memory monitoring as a global singleton
|
||||
// StartMonitoring starts continuous memory monitoring as a global singleton.
|
||||
//
|
||||
// The effective interval is resolved as follows:
|
||||
// 1. If the caller passes a positive interval, that is used.
|
||||
// 2. Otherwise the configured MemoryMonitorConfig.Interval is used.
|
||||
// 3. Otherwise the built-in default (60s) is used.
|
||||
//
|
||||
// The result is then clamped to a minimum of MinMemoryMonitorInterval (30s) to
|
||||
// avoid stop-the-world ReadMemStats storms. Callers that need rapid updates in
|
||||
// tests should call Refresh() directly instead of spinning the ticker fast.
|
||||
func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Duration) {
|
||||
globalMonitoringMutex.Lock()
|
||||
defer globalMonitoringMutex.Unlock()
|
||||
@@ -316,7 +434,17 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
|
||||
}
|
||||
|
||||
if interval <= 0 {
|
||||
interval = 30 * time.Second
|
||||
interval = mm.config.Interval
|
||||
}
|
||||
if interval <= 0 {
|
||||
interval = DefaultMemoryMonitorInterval
|
||||
}
|
||||
if interval < MinMemoryMonitorInterval {
|
||||
if !isTestMode() {
|
||||
mm.logger.Debug("Memory monitor interval %v is below minimum %v; clamping",
|
||||
interval, MinMemoryMonitorInterval)
|
||||
}
|
||||
interval = MinMemoryMonitorInterval
|
||||
}
|
||||
|
||||
registry := GetGlobalTaskRegistry()
|
||||
@@ -325,7 +453,7 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
|
||||
"memory-monitor",
|
||||
interval,
|
||||
func() {
|
||||
stats := mm.GetCurrentStats()
|
||||
stats := mm.sample()
|
||||
mm.LogMemoryStats(stats)
|
||||
mm.checkAlerts(stats)
|
||||
},
|
||||
@@ -369,14 +497,16 @@ func (mm *MemoryMonitor) checkAlerts(stats *MemoryStats) {
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerGC forces garbage collection and logs the impact
|
||||
// TriggerGC forces garbage collection and logs the impact. Both the before and
|
||||
// after measurements are fresh samples (explicit Refresh() calls) because the
|
||||
// comparison is meaningless against a stale cached snapshot.
|
||||
func (mm *MemoryMonitor) TriggerGC() {
|
||||
before := mm.GetCurrentStats()
|
||||
before := mm.Refresh()
|
||||
|
||||
runtime.GC()
|
||||
runtime.GC() // Run twice to ensure full collection
|
||||
|
||||
after := mm.GetCurrentStats()
|
||||
after := mm.Refresh()
|
||||
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
|
||||
|
||||
+11
-1
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -141,10 +142,19 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
|
||||
}
|
||||
|
||||
var metadata ProviderMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode metadata: %w", err)
|
||||
}
|
||||
|
||||
// Pin the advertised issuer to the configured provider host. The issuer is
|
||||
// the trust anchor for JWT issuer validation; rejecting a mismatch here
|
||||
// ensures a poisoned discovery document advertising an attacker-chosen
|
||||
// issuer is never cached or returned. Real providers (Google, Azure,
|
||||
// Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL.
|
||||
if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) {
|
||||
return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL)
|
||||
}
|
||||
|
||||
// Cache for 1 hour by default
|
||||
if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil {
|
||||
mc.logger.Errorf("Failed to cache metadata: %v", err)
|
||||
|
||||
+549
-131
@@ -8,11 +8,105 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// bypassReason describes why a request is being forwarded without OIDC auth.
|
||||
// It is only used for logging and to decide whether extra side-effects
|
||||
// (propagating the user header from an existing session) should run.
|
||||
const (
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
|
||||
// (RFC 6455). The middleware can only see the handshake; once Traefik
|
||||
// completes the upgrade it forwards frames directly, so we never re-process
|
||||
// per-frame traffic. We bypass auth on the handshake the same way we do for
|
||||
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// shouldBypassAuth decides whether a request must skip OIDC authentication
|
||||
// entirely. It returns (true, reason) when either the request path matches a
|
||||
// configured excluded URL, the Accept header asks for a text/event-stream
|
||||
// response (SSE), or the request is a WebSocket upgrade handshake. The
|
||||
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
|
||||
// kind (e.g. propagating user headers).
|
||||
//
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
|
||||
// WebSocket traffic is never blocked by a slow/broken provider.
|
||||
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
return true, bypassReasonExcluded
|
||||
}
|
||||
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
|
||||
return true, bypassReasonSSE
|
||||
}
|
||||
if isWebSocketUpgrade(req) {
|
||||
return true, bypassReasonWebSocket
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
|
||||
// requests and, on success, copies the authenticated user's identity onto
|
||||
// the outgoing request so downstream services can see who the user is.
|
||||
//
|
||||
// Returns true when the request carries a valid authenticated session and
|
||||
// the bypass should proceed. Returns false when no usable session is
|
||||
// present; callers must then reject the request (typically with 401) to
|
||||
// prevent unauthenticated traffic from reaching the backend just by setting
|
||||
// `Accept: text/event-stream` or sending a WebSocket upgrade.
|
||||
//
|
||||
// The check is cookie-only: the session cookie is sealed by our encryption
|
||||
// key, so the authenticated flag cannot be forged. We do NOT run full token
|
||||
// signature verification here so that SSE/WS keeps working when the OIDC
|
||||
// provider is briefly unavailable for JWK fetches.
|
||||
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
|
||||
if t.sessionManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
|
||||
return false
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
}
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
|
||||
return true
|
||||
}
|
||||
|
||||
// ServeHTTP implements the main middleware logic for processing HTTP requests.
|
||||
// It handles the complete OIDC authentication flow including:
|
||||
// - Excluded URL bypass
|
||||
@@ -52,55 +146,94 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
t.firstRequestReceived = true
|
||||
// Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock()
|
||||
// fired on EVERY non-health request forever (even after the boolean
|
||||
// flipped true), which under Yaegi added a per-request serialization
|
||||
// point. CAS gives single-firing semantics with zero steady-state cost.
|
||||
if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) {
|
||||
t.logger.Debug("Starting background tasks on first request")
|
||||
t.startTokenCleanup()
|
||||
|
||||
if !t.metadataRefreshStarted && t.providerURL != "" {
|
||||
t.metadataRefreshStarted = true
|
||||
if t.providerURL != "" &&
|
||||
atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) {
|
||||
// Metadata refresh is handled by singleton resource manager
|
||||
t.startMetadataRefresh(t.providerURL)
|
||||
}
|
||||
}
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Check excluded URLs before waiting for initialization
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for SSE requests before waiting for initialization
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/event-stream") {
|
||||
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
// URLs, SSE and WebSocket upgrade requests must not block on provider
|
||||
// init. For SSE/WebSocket we ALSO require an authenticated session
|
||||
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
|
||||
// of in-flight streams can't follow an OIDC redirect, so forwarding
|
||||
// unauthenticated traffic would silently expose the backend.
|
||||
if bypass, reason := t.shouldBypassAuth(req); bypass {
|
||||
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
|
||||
// When bearer auth is enabled, strip the Authorization header on
|
||||
// bypassed paths so a bearer token can't leak into health/metrics/
|
||||
// public endpoint logs via downstream services that don't expect it.
|
||||
// Excluded URLs are explicitly public; bearer is an artifact of the
|
||||
// API auth flow that doesn't belong on them.
|
||||
if t.enableBearerAuth {
|
||||
req.Header.Del("Authorization")
|
||||
}
|
||||
switch reason {
|
||||
case bypassReasonExcluded:
|
||||
// Operator-declared excluded URLs forward unconditionally.
|
||||
t.next.ServeHTTP(rw, req)
|
||||
case bypassReasonSSE, bypassReasonWebSocket:
|
||||
// Skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but still require an authenticated session.
|
||||
// Otherwise an unauthenticated client could hit the backend
|
||||
// just by setting Accept: text/event-stream or sending a
|
||||
// WebSocket upgrade.
|
||||
if !t.applyBypassUserHeaders(req, reason) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
default:
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Log waiting for initialization to help diagnose hanging requests
|
||||
t.logger.Debug("Waiting for OIDC provider initialization...")
|
||||
|
||||
// time.NewTimer + Stop avoids leaking a goroutine+channel for 30s on every
|
||||
// request when initComplete fires quickly (would happen with time.After).
|
||||
initTimer := time.NewTimer(30 * time.Second)
|
||||
defer initTimer.Stop()
|
||||
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
// Read issuerURL via atomic snapshot when available — replaces the
|
||||
// metadataMu.RLock that previously fired on every non-bypass request.
|
||||
// Under Yaegi each RLock acquisition costs 1-5ms of interpreter
|
||||
// dispatch; the snapshot is a single atomic.Value.Load. Falls back
|
||||
// to the legacy field+RLock for paths that haven't published a
|
||||
// snapshot yet (notably some test setups that initialize the struct
|
||||
// fields directly).
|
||||
var issuerURL string
|
||||
if snap := t.metadataSnap(); snap != nil {
|
||||
issuerURL = snap.IssuerURL
|
||||
} else {
|
||||
t.metadataMu.RLock()
|
||||
issuerURL = t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
}
|
||||
|
||||
if issuerURL == "" {
|
||||
// Provider metadata initialization failed - try to recover
|
||||
// Retry every 30 seconds to allow automatic recovery when provider comes back online
|
||||
t.metadataRetryMutex.Lock()
|
||||
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
|
||||
if shouldRetry {
|
||||
t.lastMetadataRetryTime = time.Now()
|
||||
}
|
||||
t.metadataRetryMutex.Unlock()
|
||||
// Provider metadata initialization failed - try to recover.
|
||||
// Retry every 30 seconds to allow automatic recovery. Lock-free
|
||||
// throttle via CAS on lastMetadataRetryNano: one goroutine wins
|
||||
// the window, others see shouldRetry=false.
|
||||
nowNano := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(&t.lastMetadataRetryNano)
|
||||
shouldRetry := time.Duration(nowNano-last) >= 30*time.Second &&
|
||||
atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano)
|
||||
|
||||
if shouldRetry && t.providerURL != "" {
|
||||
t.logger.Info("Attempting to recover OIDC provider metadata...")
|
||||
@@ -115,36 +248,33 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
t.sendErrorResponse(rw, req, "Request canceled", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-initTimer.C:
|
||||
t.logger.Error("Timeout waiting for OIDC initialization")
|
||||
t.sendErrorResponse(rw, req, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
acceptHeader = req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/event-stream") {
|
||||
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
||||
// Set forwarded user headers from existing session before bypassing
|
||||
if session, err := t.sessionManager.GetSession(req); err == nil {
|
||||
defer session.returnToPoolSafely()
|
||||
if email := session.GetEmail(); email != "" {
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
}
|
||||
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
|
||||
}
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
// Bypass checks already ran before the init wait; no need to repeat them.
|
||||
t.sessionManager.CleanupOldCookies(rw, req)
|
||||
|
||||
// Bearer-token auth (opt-in). Runs after init (we need issuer+JWKs+aud
|
||||
// available) and after bypass (excluded URLs always win). Cookie-vs-
|
||||
// bearer precedence is configurable; the safe default is cookie-wins.
|
||||
// See bearer_auth.go for the full pipeline.
|
||||
if t.enableBearerAuth {
|
||||
if _, hasBearer := detectBearerToken(req); hasBearer {
|
||||
cookiePresent := t.hasSessionCookie(req)
|
||||
if !cookiePresent || t.bearerOverridesCookie {
|
||||
if cookiePresent {
|
||||
t.logger.Infof("Both Authorization: Bearer and session cookie present on %s; bearer-wins per BearerOverridesCookie=true", req.URL.Path)
|
||||
}
|
||||
t.handleBearerRequest(rw, req)
|
||||
return
|
||||
}
|
||||
t.logger.Infof("Both Authorization: Bearer and session cookie present on %s; cookie-wins (default); bearer ignored", req.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
|
||||
@@ -160,6 +290,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.sendErrorResponse(rw, req, "Critical session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Sub-resource requests (script/image/fetch/serviceWorker) must not
|
||||
// trigger an OIDC redirect from this path either: they would overwrite
|
||||
// any in-flight CSRF/nonce in the session. Let the next HTML navigation
|
||||
// initiate the flow. See issue #129.
|
||||
if t.isAjaxRequest(req) || t.isNonNavigationRequest(req) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
scheme := utils.DetermineScheme(req, t.forceHTTPS)
|
||||
host := utils.DetermineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
@@ -173,12 +311,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
host := utils.DetermineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Capture per-request state: one RLock on sd.sessionMutex covers all the
|
||||
// getter values the handler chain needs (instead of 5-7 separate
|
||||
// session.GetX() calls each acquiring their own RLock under Yaegi).
|
||||
// metadataSnap is also stored once so downstream handlers don't repeat
|
||||
// the atomic.Value.Load.
|
||||
rs := (&requestState{
|
||||
scheme: scheme,
|
||||
host: host,
|
||||
redirectURL: redirectURL,
|
||||
next: t.next,
|
||||
metadata: t.metadataSnap(),
|
||||
}).captureSession(session)
|
||||
|
||||
// Check if the current request is the OIDC callback
|
||||
t.logger.Debugf("Checking callback URL match: request_path=%q, configured_callback=%q", req.URL.Path, t.redirURLPath)
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.logger.Debugf("Callback URL matched, processing OIDC callback (redirect_url=%s)", redirectURL)
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath)
|
||||
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
// Token validation reads session via the captured snapshot — saves ~21
|
||||
// sd.sessionMutex.RLock acquisitions (Yaegi-dispatched, ~1-5ms each)
|
||||
// across the validation path.
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticatedRS(rs)
|
||||
|
||||
if expired {
|
||||
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
|
||||
@@ -186,7 +344,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
userIdentifier := rs.userIdentifier
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -203,14 +361,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// methods (validateAzureTokens/validateStandardTokens) before reaching this point.
|
||||
// Redundant validation here was causing issues with Azure AD tokens that have
|
||||
// JWT format but unverifiable signatures. See issue #89.
|
||||
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
t.processAuthorizedRequestRS(rw, req, rs)
|
||||
return
|
||||
}
|
||||
|
||||
refreshTokenPresent := session.GetRefreshToken() != ""
|
||||
refreshTokenPresent := rs.refreshToken != ""
|
||||
|
||||
// Check if this is an AJAX request that should receive 401 instead of redirect
|
||||
isAjaxRequest := t.isAjaxRequest(req)
|
||||
// Decide whether to answer with 401 instead of a redirect. AJAX requests
|
||||
// cannot follow a 302 into an IdP, and sub-resource loads (script/image/
|
||||
// fetch/serviceWorker) must not trigger a fresh OIDC flow because parallel
|
||||
// loads would each overwrite the session CSRF/nonce (issue #129). Only
|
||||
// top-level HTML navigations should redirect.
|
||||
isAjaxRequest := t.isAjaxRequest(req) || t.isNonNavigationRequest(req)
|
||||
|
||||
// Check if refresh token is likely expired (older than 6 hours)
|
||||
refreshTokenExpired := refreshTokenPresent && t.isRefreshTokenExpired(session)
|
||||
@@ -254,7 +416,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
userIdentifier = session.GetUserIdentifier()
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
@@ -294,19 +456,116 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// processAuthorizedRequest processes requests for authenticated users.
|
||||
// It extracts claims, validates roles/groups if configured, sets authentication headers,
|
||||
// processes header templates, and forwards the request to the next handler.
|
||||
// Domain checks should be performed before calling this method.
|
||||
// processAuthorizedRequest processes requests for authenticated cookie/session
|
||||
// users. It performs session-specific checks (identifier presence, backchannel-
|
||||
// logout invalidation, claims extraction with potential re-auth), persists
|
||||
// dirty session state, then delegates the post-auth pipeline (roles/groups,
|
||||
// header injection, security headers, cookie strip, forward) to
|
||||
// forwardAuthorized.
|
||||
//
|
||||
// The bearer-token path uses the same forwardAuthorized helper but takes a
|
||||
// different route to it (see bearer_auth.go). Keeping forwardAuthorized
|
||||
// session-agnostic is what lets the two auth methods share one pipeline.
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The HTTP request to process.
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
//
|
||||
// processAuthorizedRequestRS is the requestState-aware variant of
|
||||
// processAuthorizedRequest. It reads SessionData fields from the captured
|
||||
// snapshot in rs instead of calling session.GetX() (each of which acquires
|
||||
// sd.sessionMutex.RLock — under Yaegi every RLock pays ~1-5ms of interpreter
|
||||
// dispatch). Only session-mutating operations (Save, ResetRedirectCount,
|
||||
// Clear, IsDirty) still go through the session pointer because those write
|
||||
// state and have no snapshot.
|
||||
func (t *TraefikOidc) processAuthorizedRequestRS(rw http.ResponseWriter, req *http.Request, rs *requestState) {
|
||||
session := rs.session
|
||||
redirectURL := rs.redirectURL
|
||||
userIdentifier := rs.userIdentifier
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if session has been invalidated via backchannel or front-channel logout
|
||||
idToken := rs.idToken
|
||||
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
}
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token.
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if rs.accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(rs.accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if groupClaimsErr != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract claims for roles/groups check: %v", groupClaimsErr)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Persist any dirty session state BEFORE forwardAuthorized writes the
|
||||
// response.
|
||||
if session.IsDirty() {
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
||||
}
|
||||
} else {
|
||||
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
|
||||
}
|
||||
|
||||
p := &principal{
|
||||
Source: sourceSession,
|
||||
Identifier: userIdentifier,
|
||||
AccessToken: rs.accessToken,
|
||||
IDToken: idToken,
|
||||
RefreshToken: rs.refreshToken,
|
||||
Claims: groupClaims,
|
||||
}
|
||||
|
||||
t.forwardAuthorized(rw, req, p)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
// Reset redirect count to prevent loops when session is invalid
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -319,7 +578,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
@@ -331,36 +590,159 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
tokenForClaims := session.GetIDToken()
|
||||
if tokenForClaims == "" {
|
||||
tokenForClaims = session.GetAccessToken()
|
||||
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
// Reset redirect count to prevent loops when token is missing
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token, so concurrent dashboard
|
||||
// panel requests on the same session don't repeatedly base64-decode and
|
||||
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
|
||||
// that hosts Traefik plugins).
|
||||
idToken := session.GetIDToken()
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
// Initialize empty slices
|
||||
var groups, roles []string
|
||||
// Choose which claims drive groups/roles extraction. Prefer the ID
|
||||
// token (cached) and fall back to the access token if there is no ID
|
||||
// token in the session — matching the prior behavior for opaque
|
||||
// ID-token providers.
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if accessToken := session.GetAccessToken(); accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if tokenForClaims != "" {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Reset redirect count to prevent loops when claim extraction fails
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
if groupClaimsErr != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
// Claims couldn't be extracted but roles checks are required:
|
||||
// re-authenticate rather than 403 (session may be salvageable on
|
||||
// re-issue). Bearer path uses 401 for the equivalent failure.
|
||||
t.logger.Errorf("Failed to extract claims for roles/groups check: %v", groupClaimsErr)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Persist any dirty session state BEFORE forwardAuthorized writes the
|
||||
// response. Once next.ServeHTTP fires, Set-Cookie can no longer reach
|
||||
// the client. The forwardAuthorized pipeline does not mutate session
|
||||
// state, so saving here is safe.
|
||||
if session.IsDirty() {
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
||||
}
|
||||
} else {
|
||||
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
|
||||
}
|
||||
|
||||
// Build the source-agnostic principal. ID-token claims drive header
|
||||
// templates and roles when present; otherwise fall back to access-token
|
||||
// claims (matches prior behavior for opaque-ID-token providers).
|
||||
p := &principal{
|
||||
Source: sourceSession,
|
||||
Identifier: userIdentifier,
|
||||
AccessToken: session.GetAccessToken(),
|
||||
IDToken: idToken,
|
||||
RefreshToken: session.GetRefreshToken(),
|
||||
Claims: groupClaims,
|
||||
}
|
||||
|
||||
t.forwardAuthorized(rw, req, p)
|
||||
}
|
||||
|
||||
// forwardAuthorized completes the post-authentication pipeline shared by the
|
||||
// cookie/session path and the bearer-token path. It performs:
|
||||
//
|
||||
// 1. Roles/groups extraction from p.Claims (idempotent; existing
|
||||
// extractGroupsAndRolesFromClaims helper).
|
||||
// 2. allowedRolesAndGroups gate — writes a 403 and returns if denied.
|
||||
// 3. Identity-header injection (X-Forwarded-User, X-User-Groups, X-User-Roles,
|
||||
// plus X-Auth-Request-* when !minimalHeaders).
|
||||
// 4. Operator-defined header templates.
|
||||
// 5. Security headers (delegated to t.securityHeadersApplier or fallback).
|
||||
// 6. OIDC session-cookie strip (stripAuthCookies).
|
||||
// 7. Authorization header strip on bearer source when stripAuthorizationHeader.
|
||||
// 8. next.ServeHTTP.
|
||||
//
|
||||
// Session persistence is the CALLER's responsibility — it must happen before
|
||||
// this function so Set-Cookie reaches the response.
|
||||
// headerTemplateMaxLen bounds the length of a rendered operator-defined header
|
||||
// template before it is forwarded downstream. Generous enough for an
|
||||
// "Authorization: Bearer <jwt>" value but small enough to reject obviously
|
||||
// abusive output. Matches the input-validation default header cap (8KB).
|
||||
const headerTemplateMaxLen = 8192
|
||||
|
||||
// headerClaimMaxLen returns the maximum accepted length for a claim-derived
|
||||
// header value (principal identifier, group, role). Reuses the operator-
|
||||
// configured identifier cap (default 256) so a single setting governs both
|
||||
// auth paths; falls back to 256 when unset.
|
||||
func (t *TraefikOidc) headerClaimMaxLen() int {
|
||||
if t.maxIdentifierLength > 0 {
|
||||
return t.maxIdentifierLength
|
||||
}
|
||||
return 256
|
||||
}
|
||||
|
||||
// sanitizeHeaderClaimList drops any group/role value that fails claim
|
||||
// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an
|
||||
// over-long value) and returns the surviving values. Failing closed on a bad
|
||||
// entry prevents header injection and stops an embedded comma from injecting
|
||||
// extra entries into the comma-joined header. headerName is used only for
|
||||
// debug logging — the value is never logged.
|
||||
func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
safe := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok {
|
||||
safe = append(safe, clean)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName)
|
||||
}
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) {
|
||||
var (
|
||||
groups, roles []string
|
||||
extractErr error
|
||||
)
|
||||
if p.Claims != nil {
|
||||
groups, roles, extractErr = t.extractGroupsAndRolesFromClaims(p.Claims)
|
||||
if extractErr != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
// Bearer path: 403 (caller already verified the token; principal
|
||||
// claims are present but malformed for roles purposes).
|
||||
// Cookie path can't reach here because processAuthorizedRequest
|
||||
// catches groupClaimsErr earlier.
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", extractErr)
|
||||
t.sendErrorResponse(rw, req, "Access denied", http.StatusForbidden)
|
||||
return
|
||||
} else if err == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
if extractErr == nil {
|
||||
// Sanitize each group/role before it is joined into a comma-
|
||||
// delimited header. The cookie/session path does not otherwise
|
||||
// sanitize claim-derived values (the bearer path sanitizes its
|
||||
// identifier at construction), so a control char would enable
|
||||
// header injection and an embedded comma would inject extra
|
||||
// entries into the comma-joined header. Fail closed: drop any
|
||||
// value that does not pass.
|
||||
if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(safeGroups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(safeRoles, ","))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -374,60 +756,73 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
t.logger.Infof("User %s does not have any allowed roles or groups", p.Identifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
// Sanitize the principal identifier before injecting it into headers. The
|
||||
// bearer path already sanitizes its identifier at construction; the
|
||||
// cookie/session path does not, so a claim carrying control chars, bidi-
|
||||
// override runes, or , ; = could inject or spoof header content. Fail
|
||||
// closed: drop the identifier header(s) rather than forward a tainted value.
|
||||
safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen())
|
||||
if identifierOK {
|
||||
req.Header.Set("X-Forwarded-User", safeIdentifier)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization")
|
||||
}
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
if identifierOK {
|
||||
req.Header.Set("X-Auth-Request-User", safeIdentifier)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization")
|
||||
}
|
||||
if p.IDToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", p.IDToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
claims, err := t.extractClaimsFunc(session.GetIDToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
|
||||
} else {
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": session.GetAccessToken(),
|
||||
"IDToken": session.GetIDToken(),
|
||||
"RefreshToken": session.GetRefreshToken(),
|
||||
"Claims": claims,
|
||||
}
|
||||
|
||||
for headerName, tmpl := range t.headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := tmpl.Execute(&buf, templateData); err != nil {
|
||||
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
|
||||
req.Header.Set(headerName, headerValue)
|
||||
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
}
|
||||
session.MarkDirty()
|
||||
t.logger.Debugf("Session marked dirty after templated header processing.")
|
||||
// p.Claims may be nil (e.g. session without an ID token). Templates
|
||||
// referencing .Claims.* will simply produce empty values — matches
|
||||
// the prior behavior. Bearer-source principals always carry access-
|
||||
// token claims (post-verifyToken).
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": p.AccessToken,
|
||||
"IDToken": p.IDToken,
|
||||
"RefreshToken": p.RefreshToken,
|
||||
"Claims": p.Claims,
|
||||
}
|
||||
}
|
||||
|
||||
if session.IsDirty() {
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
||||
for headerName, tmpl := range t.headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, templateData); err != nil {
|
||||
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
// Sanitize the rendered output: template inputs are claim-derived
|
||||
// and attacker-influenceable, so reject control chars (header
|
||||
// injection), bidi-override runes, the , ; = delimiters, and an
|
||||
// over-long value. Fail closed by dropping the header rather than
|
||||
// forwarding a tainted value. Do not log the value (it commonly
|
||||
// carries the access token); log only name + reason.
|
||||
if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" {
|
||||
t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason)
|
||||
continue
|
||||
}
|
||||
req.Header.Set(headerName, headerValue)
|
||||
// Do not log the value: templated headers commonly carry the access
|
||||
// token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging
|
||||
// it — even at debug — leaks credentials into logs.
|
||||
t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue))
|
||||
}
|
||||
} else {
|
||||
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
|
||||
}
|
||||
|
||||
// Apply security headers if configured
|
||||
@@ -441,7 +836,30 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
// Strip OIDC session cookies before forwarding to the backend to prevent
|
||||
// HTTP 431 "Request Header Fields Too Large" errors (GitHub issue #122).
|
||||
if t.stripAuthCookies && t.sessionManager != nil {
|
||||
prefix := t.sessionManager.GetCookiePrefix()
|
||||
filtered := make([]*http.Cookie, 0, len(req.Cookies()))
|
||||
for _, c := range req.Cookies() {
|
||||
if !strings.HasPrefix(c.Name, prefix) {
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
}
|
||||
req.Header.Del("Cookie")
|
||||
for _, c := range filtered {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Bearer source: strip the Authorization header to keep the raw token
|
||||
// out of downstream service logs. Off-by-config for operators who chain
|
||||
// services that each re-verify the bearer.
|
||||
if p.Source == sourceBearer && t.stripAuthorizationHeader {
|
||||
req.Header.Del("Authorization")
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s (source=%d), forwarding to next handler", p.Identifier, p.Source)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate waiting
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
// Create request with canceled context
|
||||
@@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -102,8 +102,8 @@ func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "/",
|
||||
forceHTTPS: false,
|
||||
@@ -142,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
@@ -187,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetUserIdentifier("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
@@ -236,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetUserIdentifier("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetUserIdentifier(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
// Package traefikoidc — principal abstraction for the shared post-auth
|
||||
// pipeline. A principal carries the resolved identity + tokens + claims
|
||||
// produced by EITHER the cookie session path or the bearer-token path, so
|
||||
// downstream header injection / roles checks / forwarding can be implemented
|
||||
// once and reused.
|
||||
package traefikoidc
|
||||
|
||||
// principalSource indicates which auth path produced a principal. Used by
|
||||
// forwardAuthorized to decide source-specific behavior (e.g. only strip the
|
||||
// Authorization header for bearer-source principals).
|
||||
type principalSource int
|
||||
|
||||
const (
|
||||
sourceSession principalSource = iota
|
||||
sourceBearer
|
||||
)
|
||||
|
||||
// principal is the immutable post-auth value passed to forwardAuthorized.
|
||||
// No methods mutate it; no manager pointer; no I/O. Pure data.
|
||||
type principal struct {
|
||||
Claims map[string]interface{}
|
||||
Identifier string
|
||||
Subject string
|
||||
ClientID string
|
||||
AccessToken string
|
||||
IDToken string
|
||||
RefreshToken string
|
||||
Source principalSource
|
||||
}
|
||||
|
||||
// buildPrincipalFromSession adapts an authenticated SessionData into a
|
||||
// principal value WITHOUT writing back to the session. This is the only
|
||||
// function that still knows about SessionData; the rest of the pipeline is
|
||||
// session-agnostic. Returns nil when the session has no usable identity.
|
||||
func (t *TraefikOidc) buildPrincipalFromSession(session *SessionData) *principal {
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
identifier := session.GetUserIdentifier()
|
||||
if identifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if idToken := session.GetIDToken(); idToken != "" && t.extractClaimsFunc != nil {
|
||||
// Best-effort: cached on the session, never blocking.
|
||||
claims, _ = session.GetIDTokenClaims(t.extractClaimsFunc) // Safe to ignore: claims-error path handled by header-template branch
|
||||
}
|
||||
|
||||
return &principal{
|
||||
Source: sourceSession,
|
||||
Identifier: identifier,
|
||||
AccessToken: session.GetAccessToken(),
|
||||
IDToken: session.GetIDToken(),
|
||||
RefreshToken: session.GetRefreshToken(),
|
||||
Claims: claims,
|
||||
}
|
||||
}
|
||||
+360
-230
@@ -15,19 +15,28 @@ import (
|
||||
// It implements request coalescing, rate limiting, and circuit breaking
|
||||
// specifically for token refresh operations.
|
||||
type RefreshCoordinator struct {
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
cleanupTimers map[string]*time.Timer
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
delayedCleanupQueue chan delayedCleanupItem
|
||||
// inFlightRefreshes maps tokenHash -> *refreshOperation. sync.Map is used
|
||||
// instead of a plain map + RWMutex so concurrent refreshes do not
|
||||
// serialize on a single global lock. Under Yaegi the previous
|
||||
// refreshMutex.Lock() was held for tens of milliseconds per request due
|
||||
// to interpreter overhead on the work inside the critical section,
|
||||
// causing dozens of goroutines to stack up on it and pin one CPU core.
|
||||
inFlightRefreshes sync.Map
|
||||
// sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker.
|
||||
// sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/
|
||||
// recordRefreshSuccess/recordRefreshFailure are lock-free. Previously
|
||||
// these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition
|
||||
// adds 10-50ms of dispatch overhead, and they were called twice per leader
|
||||
// request (once for recordRefreshAttempt, once for isInCooldown). That
|
||||
// serializing pattern caused the v1.0.15 death spiral after v1.0.14
|
||||
// removed the refreshMutex (same architectural shape, different mutex).
|
||||
sessionRefreshAttempts sync.Map
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
metrics *RefreshMetrics
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
config RefreshCoordinatorConfig
|
||||
wg sync.WaitGroup
|
||||
attemptsMutex sync.RWMutex
|
||||
refreshMutex sync.RWMutex
|
||||
cleanupTimerMu sync.Mutex
|
||||
}
|
||||
|
||||
// RefreshCoordinatorConfig configures the refresh coordinator behavior
|
||||
@@ -85,14 +94,46 @@ type refreshResult struct {
|
||||
fromCache bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session
|
||||
type refreshAttemptTracker struct {
|
||||
lastAttemptTime time.Time
|
||||
windowStartTime time.Time
|
||||
cooldownEndTime time.Time
|
||||
// attemptState is the immutable snapshot of a session's refresh-attempt
|
||||
// state. Lives behind refreshAttemptTracker.state (atomic.Value). Every
|
||||
// transition (record, success, failure, window-reset, cooldown-enter,
|
||||
// cooldown-exit) constructs a fresh attemptState and publishes it via
|
||||
// CompareAndSwap so the entire field set is updated together.
|
||||
//
|
||||
// Per-field atomic.Load/Store (the previous v1.0.15 design) had a benign
|
||||
// but observable hazard: the cooldown-exit reset wrote cooldownEndNano = 0
|
||||
// first, then separately stored attempts = 1 and windowStartNano = now.
|
||||
// A concurrent isInCooldown call could see cooldownEndNano = 0 (reset
|
||||
// just completed) with attempts still at MaxRefreshAttempts, triggering
|
||||
// a fresh cooldown immediately. The snapshot approach eliminates the
|
||||
// intermediate state entirely.
|
||||
type attemptState struct {
|
||||
lastAttemptNano int64 // UnixNano of last attempt
|
||||
windowStartNano int64 // UnixNano of attempt-window start
|
||||
cooldownEndNano int64 // UnixNano; 0 = not in cooldown
|
||||
attempts int32
|
||||
consecutiveFailures int32
|
||||
inCooldown bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session via a single
|
||||
// atomic.Value holding a *attemptState pointer. Readers do exactly one Load.
|
||||
// Writers do Load → construct new → CompareAndSwap (retry on conflict).
|
||||
// Under Yaegi this collapses 3-4 per-field atomic dispatches into one Load,
|
||||
// and eliminates the cross-field race in the window-reset path.
|
||||
type refreshAttemptTracker struct {
|
||||
state atomic.Value // *attemptState
|
||||
}
|
||||
|
||||
// stateOf returns the current attemptState, or a zero-value snapshot if none
|
||||
// has been published yet. The empty snapshot represents "no attempts recorded".
|
||||
func (t *refreshAttemptTracker) stateOf() *attemptState {
|
||||
if v := t.state.Load(); v != nil {
|
||||
s, _ := v.(*attemptState)
|
||||
if s != nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return &attemptState{}
|
||||
}
|
||||
|
||||
// RefreshMetrics tracks coordinator performance metrics
|
||||
@@ -107,20 +148,18 @@ type RefreshMetrics struct {
|
||||
currentInFlightRefreshes int32
|
||||
}
|
||||
|
||||
// delayedCleanupItem represents an item scheduled for delayed cleanup
|
||||
type delayedCleanupItem struct {
|
||||
cleanupAt time.Time
|
||||
tokenHash string
|
||||
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh
|
||||
// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/
|
||||
// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was
|
||||
// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of
|
||||
// interpreter dispatch per call, which compounded with attemptsMutex to keep
|
||||
// the pod's single CPU core saturated.
|
||||
type RefreshCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureNano int64 // atomic, UnixNano of most recent failure
|
||||
lastSuccessNano int64 // atomic, UnixNano of most recent success
|
||||
config RefreshCircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
state int32
|
||||
failures int32
|
||||
state int32 // atomic: 0=closed, 1=open, 2=half-open
|
||||
failures int32 // atomic
|
||||
}
|
||||
|
||||
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
|
||||
@@ -137,14 +176,12 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
}
|
||||
|
||||
rc := &RefreshCoordinator{
|
||||
inFlightRefreshes: make(map[string]*refreshOperation),
|
||||
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
delayedCleanupQueue: make(chan delayedCleanupItem, 1000), // Buffered channel for cleanup items
|
||||
cleanupTimers: make(map[string]*time.Timer),
|
||||
// inFlightRefreshes and sessionRefreshAttempts are both sync.Map;
|
||||
// their zero values are ready to use.
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
circuitBreaker: &RefreshCircuitBreaker{
|
||||
config: RefreshCircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
@@ -158,10 +195,6 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
rc.wg.Add(1)
|
||||
go rc.cleanupRoutine()
|
||||
|
||||
// Start delayed cleanup processor (single goroutine processes all cleanup timers)
|
||||
rc.wg.Add(1)
|
||||
go rc.processDelayedCleanups()
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
@@ -234,18 +267,33 @@ func (rc *RefreshCoordinator) CoordinateRefresh(
|
||||
// Returns (operation, false, nil) if joined an existing operation
|
||||
// Returns (nil, false, error) if the operation was rejected
|
||||
func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
ctx context.Context,
|
||||
_ context.Context,
|
||||
sessionID string,
|
||||
tokenHash string,
|
||||
refreshToken string,
|
||||
) (*refreshOperation, bool, error) {
|
||||
rc.refreshMutex.Lock()
|
||||
defer rc.refreshMutex.Unlock()
|
||||
// Speculatively construct the operation we WOULD register if we win the
|
||||
// race. Allocating here keeps the LoadOrStore call below atomic and
|
||||
// avoids any global lock — under Yaegi the previous map+RWMutex design
|
||||
// held the write lock long enough (tens of ms per call) that concurrent
|
||||
// refreshes on the same coordinator serialized into a queue that grew
|
||||
// without bound. See struct comment on inFlightRefreshes.
|
||||
candidate := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
|
||||
// Check for existing operation while holding the lock
|
||||
if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists {
|
||||
if existing, loaded := rc.inFlightRefreshes.LoadOrStore(tokenHash, candidate); loaded {
|
||||
existingOp, ok := existing.(*refreshOperation)
|
||||
if !ok {
|
||||
// Defensive: anything stored here is always *refreshOperation, but
|
||||
// keep the typed assert so a programming error elsewhere doesn't
|
||||
// surface as a confusing panic in an interpreter frame.
|
||||
return nil, false, fmt.Errorf("inFlightRefreshes corrupt: unexpected type %T", existing)
|
||||
}
|
||||
if existingOp.refreshToken == refreshToken {
|
||||
// Join existing operation
|
||||
atomic.AddInt32(&existingOp.waiterCount, 1)
|
||||
return existingOp, false, nil
|
||||
}
|
||||
@@ -253,47 +301,77 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
return nil, false, fmt.Errorf("refresh token mismatch")
|
||||
}
|
||||
|
||||
// No existing operation - check if we can create a new one
|
||||
// All checks happen while holding the lock to prevent races
|
||||
// We won the race and registered `candidate`. Apply gates now. If any
|
||||
// gate fails we must remove our entry from the map and signal failure
|
||||
// to any joiners that snuck in between LoadOrStore and now.
|
||||
if err := rc.applyLeaderGates(sessionID); err != nil {
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// Check and record refresh attempt for rate limiting
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
// Reserve concurrent slot via ticket-and-return: increment optimistically,
|
||||
// decrement if we overshot the limit. The previous CAS-loop allowed a
|
||||
// transient overshoot of up to N-1 leaders when several goroutines all
|
||||
// observed `current < max` in the same scheduling slice before any one
|
||||
// of them succeeded their CAS — visible to readers as
|
||||
// currentInFlightRefreshes > MaxConcurrentRefreshes for a brief window.
|
||||
// The ticket pattern is strictly bounded: the counter momentarily reads
|
||||
// max+k for k concurrent attempts past the limit, but only the k that
|
||||
// produced max+1..max+k decrement back, and only k=1 ever observes max+1
|
||||
// as committed.
|
||||
newCount := atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
if int(newCount) > rc.config.MaxConcurrentRefreshes {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
err := fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return candidate, true, nil
|
||||
}
|
||||
|
||||
// applyLeaderGates runs the rate-limit, cooldown, and memory-pressure checks
|
||||
// that previously ran under the global refreshMutex. Only the leader (the
|
||||
// goroutine that just registered the operation) runs them; joiners share the
|
||||
// leader's outcome via operation.done.
|
||||
func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
|
||||
// Cooldown check FIRST, BEFORE incrementing the attempt counter.
|
||||
// Previously this function recorded the attempt and then read the
|
||||
// cooldown state. Under burst load (many concurrent leaders with
|
||||
// different token hashes but same session) every goroutine could
|
||||
// increment past MaxRefreshAttempts before any one of them observed
|
||||
// the threshold, so the cooldown gate fired too late — the same
|
||||
// thundering-herd shape that drove v1.0.14 into the ground.
|
||||
if rc.isInCooldown(sessionID) {
|
||||
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
|
||||
return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
return fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
}
|
||||
|
||||
// Check memory pressure
|
||||
if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
|
||||
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
|
||||
return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
|
||||
return fmt.Errorf("system under memory pressure, refresh denied")
|
||||
}
|
||||
// Only count attempts that actually progress past the gates.
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check and reserve concurrent refresh slot atomically
|
||||
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
|
||||
if int(current) >= rc.config.MaxConcurrentRefreshes {
|
||||
return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
}
|
||||
|
||||
// Reserve the slot - we're still holding the lock so this is safe
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
|
||||
// Create and register new operation
|
||||
operation := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
rc.inFlightRefreshes[tokenHash] = operation
|
||||
|
||||
return operation, true, nil
|
||||
// failCandidate removes the leader's just-registered operation from the
|
||||
// in-flight map and signals the error to any joiners by recording the result
|
||||
// and closing the done channel. This keeps the (nil, false, err) return path
|
||||
// equivalent to the pre-sync.Map version: callers see the error directly,
|
||||
// joiners see it via operation.done.
|
||||
func (rc *RefreshCoordinator) failCandidate(tokenHash string, op *refreshOperation, err error) {
|
||||
rc.inFlightRefreshes.Delete(tokenHash)
|
||||
op.mutex.Lock()
|
||||
op.result = &refreshResult{err: err}
|
||||
op.mutex.Unlock()
|
||||
close(op.done)
|
||||
}
|
||||
|
||||
// executeRefreshAsync performs the actual refresh operation asynchronously
|
||||
func (rc *RefreshCoordinator) executeRefreshAsync(
|
||||
operation *refreshOperation,
|
||||
sessionID string,
|
||||
_ string, // sessionID - reserved for future metrics/logging
|
||||
tokenHash string,
|
||||
refreshFunc func() (*TokenResponse, error),
|
||||
) {
|
||||
@@ -350,159 +428,227 @@ func (rc *RefreshCoordinator) executeRefreshAsync(
|
||||
}
|
||||
}
|
||||
|
||||
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning a goroutine
|
||||
// This prevents goroutine explosion under high load (500+ req/sec)
|
||||
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning
|
||||
// a goroutine — time.AfterFunc uses the runtime's timer heap and never spawns
|
||||
// a per-timer goroutine until the callback actually fires.
|
||||
//
|
||||
// The previous implementation tracked every pending timer in a map guarded by
|
||||
// cleanupTimerMu so a duplicate scheduling could cancel the prior timer. That
|
||||
// "shouldn't happen" path was the only consumer of the map, but the mutex
|
||||
// fired on every successful refresh completion — yet another per-request
|
||||
// Yaegi-dispatched lock acquisition. performCleanup is already idempotent
|
||||
// (LoadAndDelete on the sync.Map), so a duplicate scheduling at worst fires
|
||||
// performCleanup twice; the second call is a no-op. Dropping the map removes
|
||||
// the whole class of contention on this code path.
|
||||
func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
|
||||
delay := rc.config.DeduplicationCleanupDelay
|
||||
if delay <= 0 {
|
||||
// Immediate cleanup
|
||||
rc.performCleanup(tokenHash)
|
||||
return
|
||||
}
|
||||
|
||||
// Use time.AfterFunc which is more efficient than spawning a goroutine with Sleep
|
||||
// time.AfterFunc uses the runtime's timer heap which is much more efficient
|
||||
rc.cleanupTimerMu.Lock()
|
||||
// Cancel any existing timer for this hash (shouldn't happen, but just in case)
|
||||
if existingTimer, exists := rc.cleanupTimers[tokenHash]; exists {
|
||||
existingTimer.Stop()
|
||||
}
|
||||
rc.cleanupTimers[tokenHash] = time.AfterFunc(delay, func() {
|
||||
rc.performCleanup(tokenHash)
|
||||
// Remove timer from map
|
||||
rc.cleanupTimerMu.Lock()
|
||||
delete(rc.cleanupTimers, tokenHash)
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
})
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
time.AfterFunc(delay, func() { rc.performCleanup(tokenHash) })
|
||||
}
|
||||
|
||||
// performCleanup removes the operation from the in-flight map
|
||||
// performCleanup removes the operation from the in-flight map.
|
||||
// Idempotent: only decrements the in-flight counter if an entry was actually
|
||||
// removed. LoadAndDelete is atomic so any concurrent failCandidate or repeat
|
||||
// cleanup call will see exactly one removal — the budget cannot be corrupted
|
||||
// by double-decrement.
|
||||
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
|
||||
rc.refreshMutex.Lock()
|
||||
delete(rc.inFlightRefreshes, tokenHash)
|
||||
rc.refreshMutex.Unlock()
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
if _, existed := rc.inFlightRefreshes.LoadAndDelete(tokenHash); existed {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// processDelayedCleanups processes delayed cleanup requests from the queue
|
||||
// This is a single goroutine that handles all delayed cleanups
|
||||
func (rc *RefreshCoordinator) processDelayedCleanups() {
|
||||
defer rc.wg.Done()
|
||||
// getOrCreateTracker fetches the tracker for sessionID or atomically creates a
|
||||
// fresh one. The sync.Map.LoadOrStore semantics make this lock-free even under
|
||||
// concurrent first-touch races: at most one tracker per sessionID survives.
|
||||
//
|
||||
// trackerFromMapValue centralizes the type assertion so the lint-mandated
|
||||
// two-value form lives in one place; the stored type is always
|
||||
// *refreshAttemptTracker by construction.
|
||||
func trackerFromMapValue(v interface{}) *refreshAttemptTracker {
|
||||
t, _ := v.(*refreshAttemptTracker)
|
||||
return t
|
||||
}
|
||||
|
||||
func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker {
|
||||
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
|
||||
return trackerFromMapValue(v)
|
||||
}
|
||||
fresh := &refreshAttemptTracker{}
|
||||
fresh.state.Store(&attemptState{windowStartNano: time.Now().UnixNano()})
|
||||
actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh)
|
||||
return trackerFromMapValue(actual)
|
||||
}
|
||||
|
||||
// mutateState performs a CompareAndSwap loop that applies mutate to the
|
||||
// current snapshot. mutate must be PURE: it receives an immutable view of
|
||||
// the current state and returns a fresh *attemptState. If mutate returns nil
|
||||
// the update is skipped (used by isInCooldown for "no change needed" paths).
|
||||
//
|
||||
// Retries on CAS conflict are bounded by the number of concurrent writers —
|
||||
// in practice 1-3. Under Yaegi each retry pays the dispatch cost of one Load
|
||||
// + one CompareAndSwap; still cheaper than the previous per-field atomic
|
||||
// sequence and immune to the cross-field race the v1.0.15 design had.
|
||||
func (t *refreshAttemptTracker) mutateState(mutate func(cur *attemptState) *attemptState) *attemptState {
|
||||
for {
|
||||
select {
|
||||
case item := <-rc.delayedCleanupQueue:
|
||||
// Wait until cleanup time
|
||||
waitDuration := time.Until(item.cleanupAt)
|
||||
if waitDuration > 0 {
|
||||
select {
|
||||
case <-time.After(waitDuration):
|
||||
case <-rc.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
rc.performCleanup(item.tokenHash)
|
||||
case <-rc.stopChan:
|
||||
return
|
||||
cur := t.stateOf()
|
||||
next := mutate(cur)
|
||||
if next == nil {
|
||||
return cur
|
||||
}
|
||||
if t.state.CompareAndSwap(cur, next) {
|
||||
return next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown after recording an attempt
|
||||
// isInCooldown checks if a session is in cooldown. Snapshot-based: every
|
||||
// transition publishes a fresh *attemptState atomically so readers never see
|
||||
// a partially-updated state. The previous per-field atomic design had a
|
||||
// benign race in the cooldown-exit path (cooldownEndNano reset before
|
||||
// attempts reset) that could double-trigger cooldown.
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return false // No tracker means first attempt, not in cooldown
|
||||
}
|
||||
|
||||
tracker := trackerFromMapValue(v)
|
||||
now := time.Now()
|
||||
nowNano := now.UnixNano()
|
||||
maxAttempts := rc.config.MaxRefreshAttempts
|
||||
window := rc.config.RefreshAttemptWindow
|
||||
cooldownPeriod := rc.config.RefreshCooldownPeriod
|
||||
|
||||
// Check if already in cooldown
|
||||
if tracker.inCooldown {
|
||||
if now.After(tracker.cooldownEndTime) {
|
||||
// Cooldown expired, reset tracker
|
||||
tracker.inCooldown = false
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.consecutiveFailures = 0
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
cur := tracker.stateOf()
|
||||
|
||||
// Already in cooldown?
|
||||
if cur.cooldownEndNano != 0 {
|
||||
if nowNano <= cur.cooldownEndNano {
|
||||
return true // still in cooldown
|
||||
}
|
||||
return true // Still in cooldown
|
||||
}
|
||||
|
||||
// Check if window expired
|
||||
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
|
||||
// Reset window
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.windowStartTime = now
|
||||
// Cooldown expired: atomically publish a fresh state with the window
|
||||
// restarted from one attempt. Whichever goroutine wins the CAS sets
|
||||
// the new snapshot; losers see it via the next stateOf load.
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if s.cooldownEndNano == 0 || nowNano <= s.cooldownEndNano {
|
||||
return nil // someone else already reset, or back in cooldown
|
||||
}
|
||||
return &attemptState{
|
||||
windowStartNano: nowNano,
|
||||
attempts: 1,
|
||||
}
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if just exceeded attempt limit
|
||||
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
|
||||
// Enter cooldown now
|
||||
tracker.inCooldown = true
|
||||
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, tracker.attempts)
|
||||
// Window expired?
|
||||
if time.Duration(nowNano-cur.windowStartNano) > window {
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if time.Duration(nowNano-s.windowStartNano) <= window {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.windowStartNano = nowNano
|
||||
next.attempts = 1
|
||||
return &next
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Just exceeded attempt limit?
|
||||
if int(cur.attempts) >= maxAttempts {
|
||||
end := now.Add(cooldownPeriod).UnixNano()
|
||||
published := tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if s.cooldownEndNano != 0 {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.cooldownEndNano = end
|
||||
return &next
|
||||
})
|
||||
if published.cooldownEndNano == end {
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, published.attempts)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free
|
||||
// snapshot mutation; attempts and lastAttemptNano are advanced atomically.
|
||||
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
tracker = &refreshAttemptTracker{
|
||||
windowStartTime: time.Now(),
|
||||
}
|
||||
rc.sessionRefreshAttempts[sessionID] = tracker
|
||||
}
|
||||
|
||||
atomic.AddInt32(&tracker.attempts, 1)
|
||||
tracker.lastAttemptTime = time.Now()
|
||||
tracker := rc.getOrCreateTracker(sessionID)
|
||||
nowNano := time.Now().UnixNano()
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
next := *s
|
||||
next.attempts++
|
||||
next.lastAttemptNano = nowNano
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// recordRefreshSuccess records a successful refresh
|
||||
// recordRefreshSuccess records a successful refresh: zero consecutiveFailures.
|
||||
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
tracker.consecutiveFailures = 0
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
|
||||
if s.consecutiveFailures == 0 {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.consecutiveFailures = 0
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// recordRefreshFailure records a failed refresh
|
||||
// recordRefreshFailure records a failed refresh: increments consecutiveFailures.
|
||||
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
atomic.AddInt32(&tracker.consecutiveFailures, 1)
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
|
||||
next := *s
|
||||
next.consecutiveFailures++
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
return refreshCoordinatorSessionID(token)
|
||||
}
|
||||
|
||||
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
|
||||
// for both deduplication and per-session attempt tracking. Using sha256 of the
|
||||
// raw token means each rotation produces a fresh sessionID with its own attempt
|
||||
// budget, which is what we want.
|
||||
func refreshCoordinatorSessionID(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure
|
||||
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
|
||||
// coordinated refresh result. It is wider than RefreshTimeout so a follower
|
||||
// always sees the leader's result instead of timing out independently.
|
||||
const refreshCoordinatorWaitTimeout = 35 * time.Second
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure by
|
||||
// consulting the global memory monitor. Returns true when pressure reaches
|
||||
// High or Critical, at which point we refuse new refresh operations to
|
||||
// avoid aggravating an already-stressed heap.
|
||||
func (rc *RefreshCoordinator) isUnderMemoryPressure() bool {
|
||||
// This is a simplified check - in production you'd want to use runtime.MemStats
|
||||
// or system-specific memory monitoring
|
||||
return false // Placeholder - implement actual memory check
|
||||
monitor := GetGlobalMemoryMonitor()
|
||||
if monitor == nil {
|
||||
return false
|
||||
}
|
||||
return monitor.GetMemoryPressure() >= MemoryPressureHigh
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically cleans up stale tracking entries
|
||||
@@ -522,20 +668,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleEntries removes outdated tracking entries
|
||||
// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration
|
||||
// via sync.Map.Range; safe to race with concurrent reads/writes.
|
||||
func (rc *RefreshCoordinator) cleanupStaleEntries() {
|
||||
now := time.Now()
|
||||
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
// Clean up old session trackers
|
||||
for sessionID, tracker := range rc.sessionRefreshAttempts {
|
||||
// Remove trackers that haven't been used recently
|
||||
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
|
||||
delete(rc.sessionRefreshAttempts, sessionID)
|
||||
cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano()
|
||||
rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool {
|
||||
tracker := trackerFromMapValue(value)
|
||||
if tracker == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if tracker.stateOf().lastAttemptNano < cutoff {
|
||||
// Compare-and-delete to avoid evicting a tracker that was just
|
||||
// re-used by a concurrent caller. We compare by pointer identity.
|
||||
rc.sessionRefreshAttempts.CompareAndDelete(key, value)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetMetrics returns current coordinator metrics
|
||||
@@ -553,78 +701,60 @@ func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the coordinator
|
||||
// Shutdown gracefully shuts down the coordinator. Pending delayed-cleanup
|
||||
// timers are NOT canceled explicitly: time.AfterFunc callbacks are tiny
|
||||
// (one map LoadAndDelete) and harmless after Shutdown — sync.Map operations
|
||||
// remain safe on an unused coordinator until GC.
|
||||
func (rc *RefreshCoordinator) Shutdown() {
|
||||
close(rc.stopChan)
|
||||
|
||||
// Cancel all pending cleanup timers
|
||||
rc.cleanupTimerMu.Lock()
|
||||
for _, timer := range rc.cleanupTimers {
|
||||
timer.Stop()
|
||||
}
|
||||
rc.cleanupTimers = make(map[string]*time.Timer)
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
|
||||
rc.wg.Wait()
|
||||
}
|
||||
|
||||
// AllowRequest checks if the circuit breaker allows a request
|
||||
// AllowRequest reports whether the circuit breaker allows a request. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
switch state {
|
||||
case 0: // Closed
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0: // closed
|
||||
return true
|
||||
case 1: // Open
|
||||
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
|
||||
// Try to transition to half-open
|
||||
case 1: // open
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailureNano)
|
||||
if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration {
|
||||
// Transition to half-open; first CAS winner gets the probe.
|
||||
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case 2: // Half-open
|
||||
case 2: // half-open
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
// RecordSuccess records a successful operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == 2 { // Half-open
|
||||
// Close the circuit
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 2: // half-open -> close
|
||||
atomic.StoreInt32(&cb.state, 0)
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
} else if state == 0 { // Closed
|
||||
// Reset failure count on success
|
||||
case 0: // closed
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
}
|
||||
cb.lastSuccessTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
// RecordFailure records a failed operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
failures := atomic.AddInt32(&cb.failures, 1)
|
||||
cb.lastFailureTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano())
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
if state == 0 && int(failures) >= cb.config.MaxFailures {
|
||||
// Open the circuit
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
} else if state == 2 {
|
||||
// Half-open failed, return to open
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0:
|
||||
if int(failures) >= cb.config.MaxFailures {
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
case 2:
|
||||
// Half-open probe failed -> back to open.
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
}
|
||||
|
||||
+28
-34
@@ -165,9 +165,14 @@ func TestRefreshRateLimiting(t *testing.T) {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Verify that cooldown was triggered after max attempts
|
||||
// With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
|
||||
// Verify that cooldown was triggered after max attempts.
|
||||
// With applyLeaderGates checking cooldown BEFORE recording the attempt
|
||||
// (the v1.0.16 reorder fixing the thundering-herd off-by-one), N attempts
|
||||
// run to completion and the (N+1)th is denied. Previously the Nth was
|
||||
// denied as it tried to record, which under burst load let multiple
|
||||
// concurrent leaders increment past the limit before any one of them
|
||||
// observed the gate.
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts
|
||||
if attempts != expectedSuccessfulAttempts {
|
||||
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
|
||||
}
|
||||
@@ -365,10 +370,12 @@ func TestMemoryLeakPrevention(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cleanup is working
|
||||
coordinator.attemptsMutex.RLock()
|
||||
sessionCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
// Verify cleanup is working. sync.Map has no Len(); count via Range.
|
||||
sessionCount := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
sessionCount++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have cleaned up old sessions (only recent ones remain)
|
||||
if sessionCount > numWorkers*2 {
|
||||
@@ -650,24 +657,23 @@ func TestCleanupRoutine(t *testing.T) {
|
||||
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
|
||||
}
|
||||
|
||||
// Verify sessions exist
|
||||
coordinator.attemptsMutex.RLock()
|
||||
initialCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
countSessions := func() int {
|
||||
n := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
n++
|
||||
return true
|
||||
})
|
||||
return n
|
||||
}
|
||||
|
||||
if initialCount != 5 {
|
||||
if initialCount := countSessions(); initialCount != 5 {
|
||||
t.Errorf("Expected 5 sessions, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run (2x window + cleanup interval)
|
||||
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
|
||||
|
||||
// Verify sessions were cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
finalCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
if finalCount != 0 {
|
||||
if finalCount := countSessions(); finalCount != 0 {
|
||||
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
@@ -720,11 +726,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
||||
currentGoroutines := runtime.NumGoroutine()
|
||||
t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines)
|
||||
|
||||
// Check timer count
|
||||
coordinator.cleanupTimerMu.Lock()
|
||||
timerCount := len(coordinator.cleanupTimers)
|
||||
coordinator.cleanupTimerMu.Unlock()
|
||||
t.Logf("Active cleanup timers: %d", timerCount)
|
||||
// (Coordinator no longer tracks pending timers; time.AfterFunc closures
|
||||
// fire performCleanup directly. This test now only checks the goroutine
|
||||
// budget, which was always the real invariant.)
|
||||
|
||||
// With timer-based cleanup, goroutine increase should be minimal
|
||||
// Timers don't create goroutines - they use the runtime timer heap
|
||||
@@ -740,19 +744,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
||||
initialGoroutines, currentGoroutines, goroutineIncrease)
|
||||
}
|
||||
|
||||
// Wait for timers to fire and cleanup
|
||||
// Wait for timers to fire and cleanup.
|
||||
time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond)
|
||||
|
||||
// Verify timers were cleaned up
|
||||
coordinator.cleanupTimerMu.Lock()
|
||||
remainingTimers := len(coordinator.cleanupTimers)
|
||||
coordinator.cleanupTimerMu.Unlock()
|
||||
|
||||
// Most timers should have fired and been removed
|
||||
if remainingTimers > 10 {
|
||||
t.Errorf("Too many cleanup timers remaining: %d", remainingTimers)
|
||||
}
|
||||
|
||||
// Verify goroutines returned to near initial
|
||||
runtime.GC()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stubTokenExchanger lets us count how many upstream refresh-token grants
|
||||
// happen for a given refresh_token across concurrent middleware-level calls.
|
||||
type stubTokenExchanger struct {
|
||||
calls int32
|
||||
delay time.Duration
|
||||
resp *TokenResponse
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
|
||||
// concurrent calls to coordinatedTokenRefresh with the same refresh token
|
||||
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
|
||||
//
|
||||
// Without the wireup this assertion fails (one upstream call per goroutine).
|
||||
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 100 * time.Millisecond,
|
||||
resp: &TokenResponse{
|
||||
AccessToken: "new_access",
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const concurrency = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "new_access" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
got := atomic.LoadInt32(&stub.calls)
|
||||
// Up to 2 is acceptable to absorb the documented timing slack in the
|
||||
// existing coordinator tests (e.g. operation just cleaned up before a
|
||||
// late goroutine reads the in-flight map). Anything beyond that means
|
||||
// coalescing is broken.
|
||||
if got > 2 {
|
||||
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
|
||||
// coordinator path so existing tests that build TraefikOidc literals stay
|
||||
// green.
|
||||
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenExchanger: stub,
|
||||
// refreshCoordinator deliberately nil
|
||||
}
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream call, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
|
||||
// distinct refresh tokens are not falsely coalesced.
|
||||
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 20 * time.Millisecond,
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
cfg.DeduplicationCleanupDelay = 0
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const distinct = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(distinct)
|
||||
for i := 0; i < distinct; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
|
||||
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
|
||||
// replica dedup contract: Set/Get with TTL. Used in place of the universal
|
||||
// cache singleton so these tests stay hermetic.
|
||||
type inMemoryCache struct {
|
||||
entries map[string]inMemoryCacheEntry
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type inMemoryCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
value interface{}
|
||||
}
|
||||
|
||||
func newInMemoryCache() *inMemoryCache {
|
||||
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Get(key string) (any, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
e, ok := c.entries[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(e.expiresAt) {
|
||||
delete(c.entries, key)
|
||||
return nil, false
|
||||
}
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.entries, key)
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) SetMaxSize(int) {}
|
||||
func (c *inMemoryCache) Cleanup() {}
|
||||
func (c *inMemoryCache) Close() {}
|
||||
func (c *inMemoryCache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.entries)
|
||||
}
|
||||
func (c *inMemoryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries = map[string]inMemoryCacheEntry{}
|
||||
}
|
||||
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
|
||||
|
||||
// erroringTokenExchanger always errors - simulates an IdP rejection.
|
||||
type erroringTokenExchanger struct {
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
return nil, errors.New("invalid_grant")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
|
||||
|
||||
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
|
||||
// replica having just refreshed: the shared cache already has the result, so
|
||||
// this pod must reuse it without ever calling the IdP.
|
||||
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "should_not_be_called"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
preExisting := &TokenResponse{
|
||||
AccessToken: "from_peer",
|
||||
RefreshToken: "rotated_by_peer",
|
||||
IDToken: "id_from_peer",
|
||||
}
|
||||
rt := "shared_refresh_token"
|
||||
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "from_peer" {
|
||||
t.Fatalf("expected peer-provided response, got %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 0 {
|
||||
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
|
||||
// cache miss the leader stores its result for peers to find within the TTL.
|
||||
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "fresh_grant"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
rt := "fresh_refresh_token"
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected 1 upstream call, got %d", got)
|
||||
}
|
||||
|
||||
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
|
||||
if !ok {
|
||||
t.Fatal("expected refresh result to be cached after upstream success")
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("cached value malformed: %+v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
|
||||
// dedup cache when the IdP rejects the grant. Peers must run their own
|
||||
// refresh; they cannot inherit an error.
|
||||
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
|
||||
failing := &erroringTokenExchanger{}
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: failing,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
|
||||
t.Fatal("expected an error from the failing exchanger")
|
||||
}
|
||||
if cache.Size() != 0 {
|
||||
t.Fatalf("error result must not be cached, size=%d", cache.Size())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
|
||||
// reads from. We can't reuse sessionPool.Get() here because that requires a
|
||||
// fully initialized SessionManager - overkill for this unit-level check.
|
||||
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
|
||||
t.Helper()
|
||||
rs := sessions.NewSession(nil, "refresh")
|
||||
if !issuedAt.IsZero() {
|
||||
rs.Values["issued_at"] = issuedAt.Unix()
|
||||
}
|
||||
return &SessionData{
|
||||
refreshSession: rs,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 0}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false within max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
|
||||
if !tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
|
||||
var tr *TraefikOidc
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil receiver must not panic and must return false")
|
||||
}
|
||||
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil session must return false")
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
|
||||
// Simulate successful Azure authentication
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
|
||||
assert.Equal(t, "user@example.com", session2.GetEmail())
|
||||
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
|
||||
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
|
||||
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
|
||||
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// requestState bundles read-mostly fields for a single ServeHTTP call.
|
||||
package traefikoidc
|
||||
|
||||
import "net/http"
|
||||
|
||||
// requestState is a per-request context object allocated at the top of
|
||||
// ServeHTTP and threaded through to downstream handlers. It caches values
|
||||
// that would otherwise require a Yaegi-dispatched lock acquisition each time
|
||||
// they're read:
|
||||
//
|
||||
// - The metadata snapshot (atomic.Value.Load once, not per-handler).
|
||||
// - SessionData getter results (one RLock on sd.sessionMutex covers all
|
||||
// fields, instead of 5-7 separate RLock/RUnlock pairs scattered through
|
||||
// the handler chain).
|
||||
//
|
||||
// The struct is alloc'd at request entry, populated under at most one RLock
|
||||
// of sd.sessionMutex, and discarded at request exit. It is NOT shared across
|
||||
// requests and never written from another goroutine, so no synchronization
|
||||
// on its fields is required.
|
||||
//
|
||||
// Cross-request global caches (tokenCache, JWKCache, sessionEntries,
|
||||
// sessionInvalidationCache) remain — they're orthogonal. requestState's job
|
||||
// is to eliminate redundant per-handler reads of values that don't change
|
||||
// within a single request.
|
||||
type requestState struct {
|
||||
// Globals snapshotted once.
|
||||
metadata *MetadataSnapshot
|
||||
|
||||
// SessionData fields snapshotted under one RLock. The pointer to the
|
||||
// SessionData is retained so handlers that genuinely need to mutate
|
||||
// (Save, Clear, etc.) still have access.
|
||||
session *SessionData
|
||||
|
||||
authenticated bool
|
||||
accessToken string
|
||||
idToken string
|
||||
refreshToken string
|
||||
userIdentifier string
|
||||
createdAtUnixSec int64
|
||||
|
||||
// Output: scheme/host/redirect path determined at top of ServeHTTP.
|
||||
scheme string
|
||||
host string
|
||||
redirectURL string
|
||||
|
||||
// Carry the next handler so forwardAuthorized doesn't need to close over t.
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// captureSession populates requestState's SessionData-derived fields under a
|
||||
// single RLock of sd.sessionMutex. Returns the populated rs for chaining.
|
||||
//
|
||||
// Replaces a sequence of SessionData.GetX() calls each of which acquires
|
||||
// sd.sessionMutex.RLock(). Under Yaegi each RLock costs ~1-5ms of
|
||||
// interpreter dispatch; batching saves the rest.
|
||||
func (rs *requestState) captureSession(sd *SessionData) *requestState {
|
||||
if sd == nil {
|
||||
return rs
|
||||
}
|
||||
rs.session = sd
|
||||
sd.sessionMutex.RLock()
|
||||
rs.authenticated = sd.getAuthenticatedUnsafe()
|
||||
rs.accessToken = sd.getAccessTokenUnsafe()
|
||||
rs.idToken = sd.getIDTokenUnsafe()
|
||||
rs.refreshToken = sd.getRefreshTokenUnsafe()
|
||||
rs.userIdentifier = sd.getUserIdentifierUnsafe()
|
||||
rs.createdAtUnixSec = sd.getCreatedAtUnsafe()
|
||||
sd.sessionMutex.RUnlock()
|
||||
return rs
|
||||
}
|
||||
@@ -0,0 +1,404 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// TestRank1_SessionCookieIsEncrypted verifies that the session cookie payload is
|
||||
// AES-encrypted, not merely HMAC-signed. Regression test for the audit finding
|
||||
// "session cookies signed but NOT encrypted": a single key left the stored OIDC
|
||||
// tokens recoverable in plaintext from the raw cookie bytes.
|
||||
func TestRank1_SessionCookieIsEncrypted(t *testing.T) {
|
||||
const secret = "a-sufficiently-long-session-encryption-key"
|
||||
authKey, encKey := deriveCookieKeys(secret)
|
||||
if len(authKey) != 64 || len(encKey) != 32 {
|
||||
t.Fatalf("expected 64-byte auth key and 32-byte enc key, got %d/%d", len(authKey), len(encKey))
|
||||
}
|
||||
if string(authKey) == string(encKey) {
|
||||
t.Fatal("authentication and encryption keys must be independent")
|
||||
}
|
||||
|
||||
const marker = "SUPER-SECRET-ACCESS-TOKEN-marker-value"
|
||||
|
||||
// Encode a session through the same two-key store the production code now
|
||||
// builds (see NewSessionManager).
|
||||
store := sessions.NewCookieStore(authKey, encKey)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
sess, err := store.New(req, "session")
|
||||
if err != nil {
|
||||
t.Fatalf("store.New failed: %v", err)
|
||||
}
|
||||
sess.Values["tok"] = marker
|
||||
if err := sess.Save(req, rec); err != nil {
|
||||
t.Fatalf("session save failed: %v", err)
|
||||
}
|
||||
|
||||
var cookie *http.Cookie
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == "session" {
|
||||
cookie = c
|
||||
}
|
||||
}
|
||||
if cookie == nil {
|
||||
t.Fatal("no session cookie was set")
|
||||
}
|
||||
|
||||
// The secret token must never appear in plaintext in the cookie value.
|
||||
if strings.Contains(cookie.Value, marker) {
|
||||
t.Error("marker token found in plaintext inside the session cookie value")
|
||||
}
|
||||
|
||||
// A store holding only the authentication key (the previous behavior)
|
||||
// must NOT be able to read the encrypted cookie — proving the payload is
|
||||
// genuinely encrypted, not just signed.
|
||||
signedOnly := sessions.NewCookieStore(authKey)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.AddCookie(cookie)
|
||||
if _, derr := signedOnly.Get(req2, "session"); derr == nil {
|
||||
t.Error("encrypted cookie should not be decodable without the encryption key")
|
||||
}
|
||||
|
||||
// The full two-key store round-trips correctly.
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req3.AddCookie(cookie)
|
||||
rt, derr := store.Get(req3, "session")
|
||||
if derr != nil {
|
||||
t.Fatalf("round-trip decode failed: %v", derr)
|
||||
}
|
||||
if got, _ := rt.Values["tok"].(string); got != marker {
|
||||
t.Errorf("round-trip mismatch: got %q want %q", got, marker)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank2And6_InvalidConfigFailsClosed verifies that NewWithContext now calls
|
||||
// Config.Validate() and fails closed on an empty or too-short session
|
||||
// encryption key instead of silently substituting a public hardcoded key, and
|
||||
// rejects other missing required fields. Regression test for "hardcoded default
|
||||
// encryption key" + "Config.Validate() never called in production path".
|
||||
func TestRank2And6_InvalidConfigFailsClosed(t *testing.T) {
|
||||
base := func() *Config {
|
||||
return &Config{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "this-is-a-valid-session-key-32b!",
|
||||
RateLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity: a fully valid config still constructs.
|
||||
p, err := NewWithContext(context.Background(), base(), nil, "valid")
|
||||
if err != nil {
|
||||
t.Fatalf("valid config should construct, got: %v", err)
|
||||
}
|
||||
if p != nil {
|
||||
p.Close()
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
}{
|
||||
{"empty key", func(c *Config) { c.SessionEncryptionKey = "" }},
|
||||
{"short key", func(c *Config) { c.SessionEncryptionKey = "tooshort" }},
|
||||
{"missing providerURL", func(c *Config) { c.ProviderURL = "" }},
|
||||
{"missing callbackURL", func(c *Config) { c.CallbackURL = "" }},
|
||||
{"plaintext remote providerURL", func(c *Config) { c.ProviderURL = "http://accounts.google.com" }},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := base()
|
||||
tc.mutate(c)
|
||||
plugin, err := NewWithContext(context.Background(), c, nil, tc.name)
|
||||
if err == nil {
|
||||
if plugin != nil {
|
||||
plugin.Close()
|
||||
}
|
||||
t.Errorf("expected NewWithContext to reject config (%s), but it succeeded", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank3_DiscoveredEndpointSSRFGuard verifies that endpoints from the
|
||||
// provider discovery document are screened against SSRF targets before use.
|
||||
func TestRank3_DiscoveredEndpointSSRFGuard(t *testing.T) {
|
||||
tr := &TraefikOidc{}
|
||||
|
||||
blocked := []string{
|
||||
"http://169.254.169.254/latest/meta-data/", // cloud metadata (link-local)
|
||||
"http://[fe80::1]/jwks", // IPv6 link-local
|
||||
"http://10.0.0.5/jwks", // private
|
||||
"http://192.168.1.10/jwks", // private
|
||||
"http://127.0.0.1/jwks", // loopback (allowLoopback=false)
|
||||
"ftp://example.com/jwks", // disallowed scheme
|
||||
}
|
||||
for _, u := range blocked {
|
||||
if err := tr.validateDiscoveredEndpoint(u, false); err == nil {
|
||||
t.Errorf("expected discovered endpoint %q to be rejected", u)
|
||||
}
|
||||
}
|
||||
|
||||
allowed := []string{
|
||||
"https://accounts.google.com/o/oauth2/v3/certs",
|
||||
"https://www.googleapis.com/oauth2/v3/certs", // cross-domain JWKS must stay allowed
|
||||
"", // empty optional endpoint
|
||||
}
|
||||
for _, u := range allowed {
|
||||
if err := tr.validateDiscoveredEndpoint(u, false); err != nil {
|
||||
t.Errorf("expected discovered endpoint %q to be allowed, got %v", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Loopback is allowed only when the provider itself is loopback (dev/test).
|
||||
if err := tr.validateDiscoveredEndpoint("http://127.0.0.1:8080/jwks", true); err != nil {
|
||||
t.Errorf("loopback endpoint should be allowed when allowLoopback=true: %v", err)
|
||||
}
|
||||
// Private addresses are allowed when explicitly opted in.
|
||||
trPriv := &TraefikOidc{allowPrivateIPAddresses: true}
|
||||
if err := trPriv.validateDiscoveredEndpoint("http://10.0.0.5/jwks", false); err != nil {
|
||||
t.Errorf("private endpoint should be allowed when allowPrivateIPAddresses=true: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank4_IntrospectionHostPin verifies the host-equality check used to pin
|
||||
// the credential-bearing introspection endpoint to the configured provider.
|
||||
func TestRank4_IntrospectionHostPin(t *testing.T) {
|
||||
if !sameHost("https://kc.example.com/realms/x", "https://kc.example.com/realms/x/protocol/openid-connect/token/introspect") {
|
||||
t.Error("introspection on the same host as the provider should be accepted")
|
||||
}
|
||||
if sameHost("https://kc.example.com", "https://evil.example.net/introspect") {
|
||||
t.Error("introspection on a different host must be rejected")
|
||||
}
|
||||
if sameHost("", "https://kc.example.com") || sameHost("https://kc.example.com", "") {
|
||||
t.Error("empty URL must not be treated as a host match")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank5_OpenRedirectNeutralized verifies the helper the callback now applies
|
||||
// to the stored incoming path forces a host-relative redirect target.
|
||||
func TestRank5_OpenRedirectNeutralized(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"//evil.com/x": "/evil.com/x",
|
||||
`/\evil.com`: "/evil.com",
|
||||
"/legit/path": "/legit/path",
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := normalizeLogoutPath(in)
|
||||
if got != want {
|
||||
t.Errorf("normalizeLogoutPath(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
if strings.HasPrefix(got, "//") || strings.HasPrefix(got, `/\`) {
|
||||
t.Errorf("normalizeLogoutPath(%q) = %q is still protocol-relative", in, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank14_ExcludedURLSegmentBoundary verifies excluded-URL matching is
|
||||
// anchored at path-segment boundaries and cannot be widened into a bypass.
|
||||
func TestRank14_ExcludedURLSegmentBoundary(t *testing.T) {
|
||||
if !pathExcluded("/public", "/public") {
|
||||
t.Error("exact match should be excluded")
|
||||
}
|
||||
if !pathExcluded("/public/page", "/public") {
|
||||
t.Error("sub-path should be excluded")
|
||||
}
|
||||
if pathExcluded("/publicsecret", "/public") {
|
||||
t.Error("/publicsecret must NOT be excluded by /public")
|
||||
}
|
||||
if pathExcluded("/public-admin", "/public") {
|
||||
t.Error("/public-admin must NOT be excluded by /public")
|
||||
}
|
||||
if !pathExcluded("/health", "/health/") {
|
||||
t.Error("trailing-slash config should still match the exact path")
|
||||
}
|
||||
if pathExcluded("/anything", "/") {
|
||||
t.Error("root exclusion must not match arbitrary paths")
|
||||
}
|
||||
if !pathExcluded("/", "/") {
|
||||
t.Error("root exclusion should match the root path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank15_ForwardedHostSanitized verifies a crafted X-Forwarded-Host cannot
|
||||
// inject CRLF, smuggle a second host, or otherwise poison the derived host.
|
||||
func TestRank15_ForwardedHostSanitized(t *testing.T) {
|
||||
mk := func(xfh string) *http.Request {
|
||||
r := httptest.NewRequest(http.MethodGet, "http://real.example.com/x", nil)
|
||||
r.Host = "real.example.com"
|
||||
if xfh != "" {
|
||||
r.Header.Set("X-Forwarded-Host", xfh)
|
||||
}
|
||||
return r
|
||||
}
|
||||
if got := utils.DetermineHost(mk("ext.example.com")); got != "ext.example.com" {
|
||||
t.Errorf("clean X-Forwarded-Host should be honored, got %q", got)
|
||||
}
|
||||
if got := utils.DetermineHost(mk("a.example.com, evil.com")); got != "a.example.com" {
|
||||
t.Errorf("multi-value X-Forwarded-Host should use first host only, got %q", got)
|
||||
}
|
||||
for _, bad := range []string{"evil.com\r\nSet-Cookie: x=1", "evil.com /x", " "} {
|
||||
if got := utils.DetermineHost(mk(bad)); got != "real.example.com" {
|
||||
t.Errorf("malformed X-Forwarded-Host %q should fall back to req.Host, got %q", bad, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank11_TransportPoolTLSIsolationAtLimit verifies that, once the client
|
||||
// limit is reached, the transport pool reuses an existing transport only when
|
||||
// its TLS settings match the caller's, and never hands back a transport built
|
||||
// with different TLS trust settings.
|
||||
func TestRank11_TransportPoolTLSIsolationAtLimit(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
strict := DefaultHTTPClientConfig() // InsecureSkipVerify = false
|
||||
t1 := pool.GetOrCreateTransport(strict)
|
||||
if t1 == nil {
|
||||
t.Fatal("expected a transport for the strict config")
|
||||
}
|
||||
|
||||
// Saturate the client limit so subsequent calls hit the fallback path.
|
||||
atomic.StoreInt32(&pool.clientCount, pool.maxClients)
|
||||
|
||||
// Same TLS settings, different (non-TLS) connection limit: safe to reuse.
|
||||
sameTLS := DefaultHTTPClientConfig()
|
||||
sameTLS.MaxConnsPerHost = 99
|
||||
if got := pool.GetOrCreateTransport(sameTLS); got != t1 {
|
||||
t.Error("at the limit a TLS-compatible config should reuse the existing transport")
|
||||
}
|
||||
|
||||
// Different TLS settings (InsecureSkipVerify): must NOT reuse the strict
|
||||
// transport — returning nil lets the caller fall back to a verifying default.
|
||||
insecure := DefaultHTTPClientConfig()
|
||||
insecure.InsecureSkipVerify = true
|
||||
if got := pool.GetOrCreateTransport(insecure); got == t1 {
|
||||
t.Error("at the limit a config with different TLS settings must not reuse the strict transport")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank9_RedisFingerprint verifies divergent explicit Redis backends produce
|
||||
// distinct fingerprints (used to warn about ignored cache config), while an
|
||||
// absent or disabled Redis yields the empty (no-warning) fingerprint.
|
||||
func TestRank9_RedisFingerprint(t *testing.T) {
|
||||
if redisFingerprint(nil) != "" {
|
||||
t.Error("nil config should yield an empty fingerprint")
|
||||
}
|
||||
if redisFingerprint(&Config{}) != "" {
|
||||
t.Error("config without Redis should yield an empty fingerprint")
|
||||
}
|
||||
if redisFingerprint(&Config{Redis: &RedisConfig{Enabled: false, Address: "a:6379"}}) != "" {
|
||||
t.Error("disabled Redis should yield an empty fingerprint")
|
||||
}
|
||||
a := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "a:6379", KeyPrefix: "p"}})
|
||||
b := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "b:6379", KeyPrefix: "p"}})
|
||||
if a == "" || a == b {
|
||||
t.Errorf("distinct enabled backends must produce distinct non-empty fingerprints (%q vs %q)", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank10_TokenTypeCacheKeyNoCollision verifies that two different tokens
|
||||
// sharing the same 32-character JWT header prefix are classified independently.
|
||||
// The previous 32-char cache key would have collided and mis-classified them.
|
||||
func TestRank10_TokenTypeCacheKeyNoCollision(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
tokenTypeCache: NewCache(),
|
||||
suppressDiagnosticLogs: true,
|
||||
clientID: "client",
|
||||
}
|
||||
// A header prefix longer than 32 chars, shared by both tokens.
|
||||
prefix := "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ"
|
||||
idJWT := &JWT{Header: map[string]interface{}{}, Claims: map[string]interface{}{"nonce": "n"}}
|
||||
accessJWT := &JWT{Header: map[string]interface{}{"typ": "at+jwt"}, Claims: map[string]interface{}{}}
|
||||
|
||||
if !tr.detectTokenType(idJWT, prefix+".id.sig") {
|
||||
t.Error("token with a nonce claim should be detected as an ID token")
|
||||
}
|
||||
if tr.detectTokenType(accessJWT, prefix+".access.sig") {
|
||||
t.Error("access token (typ=at+jwt) must not be mis-classified as ID despite the shared 32-char prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank12_LiveInstanceCounter verifies the process-global instance counter
|
||||
// that gates teardown of shared singleton tasks.
|
||||
func TestRank12_LiveInstanceCounter(t *testing.T) {
|
||||
start := atomic.LoadInt32(&liveInstanceCount)
|
||||
registerLiveInstance()
|
||||
registerLiveInstance()
|
||||
if got := atomic.LoadInt32(&liveInstanceCount); got != start+2 {
|
||||
t.Fatalf("expected %d live instances, got %d", start+2, got)
|
||||
}
|
||||
if rem := unregisterLiveInstance(); rem != start+1 {
|
||||
t.Errorf("expected %d remaining, got %d", start+1, rem)
|
||||
}
|
||||
if rem := unregisterLiveInstance(); rem != start {
|
||||
t.Errorf("expected %d remaining, got %d", start, rem)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank13_CookieMaxAgeMatchesSessionLifetime verifies the cookie store's
|
||||
// MaxAge (which bounds both the cookie Max-Age and the codec's cryptographic
|
||||
// timestamp validity) is bound to the configured session lifetime rather than
|
||||
// gorilla's 30-day default.
|
||||
func TestRank13_CookieMaxAgeMatchesSessionLifetime(t *testing.T) {
|
||||
maxAge := 2 * time.Hour
|
||||
sm, err := NewSessionManager(strings.Repeat("k", 40), false, "", "", maxAge, NewLogger("error"))
|
||||
if err != nil {
|
||||
t.Fatalf("NewSessionManager failed: %v", err)
|
||||
}
|
||||
defer sm.cancel()
|
||||
|
||||
cs, ok := sm.store.(*sessions.CookieStore)
|
||||
if !ok {
|
||||
t.Fatal("session store is not a *sessions.CookieStore")
|
||||
}
|
||||
if got := cs.Options.MaxAge; got != int(maxAge.Seconds()) {
|
||||
t.Errorf("cookie store MaxAge = %d, want %d (bound to sessionMaxAge)", got, int(maxAge.Seconds()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank33And34_HeaderSanitizationDistinction verifies the two header sinks
|
||||
// use the right strictness: free-form templated header VALUES (rank 34) permit
|
||||
// , ; = (e.g. an opaque "Bearer <token>" or an LDAP-DN claim) but reject CR/LF,
|
||||
// bidi, and over-length; claim values joined into delimited/identifier headers
|
||||
// (rank 33) additionally reject , ; =.
|
||||
func TestRank33And34_HeaderSanitizationDistinction(t *testing.T) {
|
||||
// Rank 34 — free-form header value.
|
||||
if headerValueReason("Bearer abc=def==", 8192) != "" {
|
||||
t.Error("'=' must be allowed in a free-form header value (opaque bearer token)")
|
||||
}
|
||||
if headerValueReason("cn=user,ou=eng;dc=x", 8192) != "" {
|
||||
t.Error("',;=' must be allowed in a free-form header value (e.g. an LDAP DN claim)")
|
||||
}
|
||||
if headerValueReason("evil"+string(rune(13))+string(rune(10))+"Injected: 1", 8192) == "" {
|
||||
t.Error("CR/LF must be rejected in a header value (injection)")
|
||||
}
|
||||
if headerValueReason("toolong", 3) == "" {
|
||||
t.Error("over-length value must be rejected")
|
||||
}
|
||||
|
||||
// Rank 33 — claim value bound for a delimited/identifier header.
|
||||
if _, ok := sanitizeHeaderClaimValue("admins,superadmins", 256); ok {
|
||||
t.Error("a comma must be rejected in a value joined into a comma-delimited header")
|
||||
}
|
||||
if _, ok := sanitizeHeaderClaimValue("normal-user@example.com", 256); !ok {
|
||||
t.Error("a clean identifier must pass claim sanitization")
|
||||
}
|
||||
if _, ok := sanitizeHeaderClaimValue("evil"+string(rune(13))+string(rune(10))+"X: 1", 256); ok {
|
||||
t.Error("CR/LF must be rejected in a claim value")
|
||||
}
|
||||
}
|
||||
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetUserIdentifier("attacker@evil.com")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get the email from the session
|
||||
email := session.GetEmail()
|
||||
email := session.GetUserIdentifier()
|
||||
w.Header().Set("X-User-Email", email)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
@@ -1,590 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityEventType categorizes different types of security events
|
||||
// that can occur during OIDC authentication and authorization flows.
|
||||
type SecurityEventType string
|
||||
|
||||
// Security event types for monitoring and alerting
|
||||
const (
|
||||
// AuthFailure indicates a failed authentication attempt
|
||||
AuthFailure SecurityEventType = "authentication_failure"
|
||||
// TokenValidFailure indicates JWT token validation failed
|
||||
TokenValidFailure SecurityEventType = "token_validation_failure"
|
||||
// RateLimitHit indicates rate limiting was triggered
|
||||
RateLimitHit SecurityEventType = "rate_limit_hit"
|
||||
// SuspiciousActivity indicates potentially malicious behavior
|
||||
SuspiciousActivity SecurityEventType = "suspicious_activity"
|
||||
)
|
||||
|
||||
// DefaultSeverity returns the default severity level for each security event type.
|
||||
// Severity levels are: low, medium, high.
|
||||
func (t SecurityEventType) DefaultSeverity() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "medium"
|
||||
case TokenValidFailure:
|
||||
return "medium"
|
||||
case RateLimitHit:
|
||||
return "low"
|
||||
case SuspiciousActivity:
|
||||
return "high"
|
||||
default:
|
||||
return "medium"
|
||||
}
|
||||
}
|
||||
|
||||
// IPFailureType returns a string identifier for categorizing failures
|
||||
// by IP address for rate limiting and blocking decisions.
|
||||
func (t SecurityEventType) IPFailureType() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "auth_failure"
|
||||
case TokenValidFailure:
|
||||
return "token_failure"
|
||||
case SuspiciousActivity:
|
||||
return "suspicious"
|
||||
default:
|
||||
return "general"
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityEvent represents a security-related event with comprehensive context.
|
||||
// Contains timing information, IP address, user agent, request details,
|
||||
// and custom event-specific data for security analysis and alerting.
|
||||
type SecurityEvent struct {
|
||||
// Timestamp when the event occurred
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
// Details contains event-specific additional information
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
// Type categorizes the event (auth_failure, token_failure, etc.)
|
||||
Type string `json:"type"`
|
||||
// Severity indicates event importance (low, medium, high)
|
||||
Severity string `json:"severity"`
|
||||
// ClientIP is the source IP address of the request
|
||||
ClientIP string `json:"client_ip"`
|
||||
// UserAgent is the User-Agent header from the request
|
||||
UserAgent string `json:"user_agent"`
|
||||
// RequestPath is the requested URL path
|
||||
RequestPath string `json:"request_path"`
|
||||
// Message provides human-readable description of the event
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SecurityMonitor provides comprehensive security monitoring for the OIDC middleware.
|
||||
// It tracks failures by IP address, detects suspicious patterns, enforces
|
||||
// rate limits, and can trigger custom security event handlers.
|
||||
type SecurityMonitor struct {
|
||||
ipFailures map[string]*IPFailureTracker
|
||||
patternDetector *SuspiciousPatternDetector
|
||||
logger *Logger
|
||||
cleanupTask *BackgroundTask
|
||||
eventHandlers []SecurityEventHandler
|
||||
config SecurityMonitorConfig
|
||||
ipMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// IPFailureTracker maintains failure statistics and blocking state for an IP address.
|
||||
// Used for implementing progressive penalties and automatic IP blocking based on
|
||||
// failure patterns, with support for different failure types for
|
||||
// rate limiting and IP blocking decisions.
|
||||
type IPFailureTracker struct {
|
||||
// LastFailure timestamp of the most recent failure
|
||||
LastFailure time.Time
|
||||
// FirstFailure timestamp of the first failure in current window
|
||||
FirstFailure time.Time
|
||||
// BlockedUntil indicates when the IP block expires
|
||||
BlockedUntil time.Time
|
||||
// FailureTypes tracks counts by failure type
|
||||
FailureTypes map[string]int64
|
||||
// FailureCount total number of failures
|
||||
FailureCount int64
|
||||
// mutex protects concurrent access to tracker data
|
||||
mutex sync.RWMutex
|
||||
// IsBlocked indicates if this IP is currently blocked
|
||||
IsBlocked bool
|
||||
}
|
||||
|
||||
// SuspiciousPatternDetector identifies attack patterns that may indicate coordinated threats.
|
||||
// Analyzes events across multiple time windows to detect rapid failures, distributed attacks,
|
||||
// and persistent attack patterns that individual IP monitoring might miss.
|
||||
type SuspiciousPatternDetector struct {
|
||||
// recentEvents stores recent security events for analysis
|
||||
recentEvents []SecurityEvent
|
||||
// shortWindow defines time frame for rapid failure detection
|
||||
shortWindow time.Duration
|
||||
// mediumWindow defines time frame for distributed attack detection
|
||||
mediumWindow time.Duration
|
||||
// longWindow defines time frame for persistent attack detection
|
||||
longWindow time.Duration
|
||||
// rapidFailureThreshold triggers rapid failure alerts
|
||||
rapidFailureThreshold int
|
||||
// distributedAttackThreshold triggers distributed attack alerts
|
||||
distributedAttackThreshold int
|
||||
// persistentAttackThreshold triggers persistent attack alerts
|
||||
persistentAttackThreshold int
|
||||
// eventsMutex protects concurrent access to events
|
||||
eventsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// SecurityEventHandler defines the interface for processing security events.
|
||||
// Implementations can log events, send alerts, update external systems,
|
||||
// or trigger automated response actions.
|
||||
type SecurityEventHandler interface {
|
||||
// HandleSecurityEvent processes a security event
|
||||
HandleSecurityEvent(event SecurityEvent)
|
||||
}
|
||||
|
||||
// SecurityMonitorConfig contains configuration parameters for the security monitor.
|
||||
// Controls thresholds, time windows, and behavior for security monitoring.
|
||||
type SecurityMonitorConfig struct {
|
||||
// MaxFailuresPerIP sets the failure threshold before blocking
|
||||
MaxFailuresPerIP int `json:"max_failures_per_ip"`
|
||||
// FailureWindowMinutes defines the time window for counting failures
|
||||
FailureWindowMinutes int `json:"failure_window_minutes"`
|
||||
// BlockDurationMinutes sets how long to block an IP
|
||||
BlockDurationMinutes int `json:"block_duration_minutes"`
|
||||
// RapidFailureThreshold triggers rapid failure detection
|
||||
RapidFailureThreshold int `json:"rapid_failure_threshold"`
|
||||
// CleanupIntervalMinutes sets cleanup frequency for old data
|
||||
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
|
||||
RetentionHours int `json:"retention_hours"`
|
||||
EnablePatternDetection bool `json:"enable_pattern_detection"`
|
||||
EnableDetailedLogging bool `json:"enable_detailed_logging"`
|
||||
LogSuspiciousOnly bool `json:"log_suspicious_only"`
|
||||
}
|
||||
|
||||
// DefaultSecurityMonitorConfig returns a default configuration
|
||||
func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
|
||||
return SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 10,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 60,
|
||||
EnablePatternDetection: true,
|
||||
RapidFailureThreshold: 5,
|
||||
EnableDetailedLogging: true,
|
||||
LogSuspiciousOnly: false,
|
||||
CleanupIntervalMinutes: 30,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSecurityMonitor creates a new security monitor instance
|
||||
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
|
||||
sm := &SecurityMonitor{
|
||||
ipFailures: make(map[string]*IPFailureTracker),
|
||||
eventHandlers: make([]SecurityEventHandler, 0),
|
||||
config: config,
|
||||
logger: logger,
|
||||
patternDetector: NewSuspiciousPatternDetector(),
|
||||
}
|
||||
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// NewSuspiciousPatternDetector creates a new pattern detector
|
||||
func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
|
||||
return &SuspiciousPatternDetector{
|
||||
shortWindow: 1 * time.Minute,
|
||||
mediumWindow: 5 * time.Minute,
|
||||
longWindow: 15 * time.Minute,
|
||||
rapidFailureThreshold: 5,
|
||||
distributedAttackThreshold: 20,
|
||||
persistentAttackThreshold: 50,
|
||||
recentEvents: make([]SecurityEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSecurityEvent is a generic method to record any type of security event
|
||||
func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
eventType SecurityEventType,
|
||||
clientIP, userAgent, requestPath string,
|
||||
message string,
|
||||
details map[string]interface{},
|
||||
trackIPFailure bool) {
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: string(eventType),
|
||||
Severity: eventType.DefaultSeverity(),
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
if trackIPFailure {
|
||||
sm.recordIPFailure(clientIP, eventType.IPFailureType())
|
||||
}
|
||||
|
||||
sm.processSecurityEvent(event)
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["reason"] = reason
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
AuthFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Authentication failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordTokenValidationFailure records a token validation failure
|
||||
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
|
||||
details := map[string]interface{}{
|
||||
"reason": reason,
|
||||
}
|
||||
if tokenPrefix != "" {
|
||||
details["token_prefix"] = tokenPrefix
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
TokenValidFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Token validation failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordRateLimitHit records when rate limiting is triggered
|
||||
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
|
||||
details := map[string]interface{}{
|
||||
"limit_type": "token_verification",
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
RateLimitHit,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
"Rate limit exceeded",
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["activity_type"] = activityType
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
SuspiciousActivity,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// recordIPFailure tracks failures for a specific IP address
|
||||
func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
tracker = &IPFailureTracker{
|
||||
FailureTypes: make(map[string]int64),
|
||||
FirstFailure: time.Now(),
|
||||
}
|
||||
sm.ipFailures[clientIP] = tracker
|
||||
}
|
||||
|
||||
tracker.mutex.Lock()
|
||||
defer tracker.mutex.Unlock()
|
||||
|
||||
tracker.FailureCount++
|
||||
tracker.LastFailure = time.Now()
|
||||
tracker.FailureTypes[failureType]++
|
||||
|
||||
windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute)
|
||||
if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) {
|
||||
if !tracker.IsBlocked {
|
||||
tracker.IsBlocked = true
|
||||
tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute)
|
||||
|
||||
sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes)
|
||||
|
||||
blockEvent := SecurityEvent{
|
||||
Type: "ip_blocked",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
|
||||
Details: map[string]interface{}{
|
||||
"failure_count": tracker.FailureCount,
|
||||
"failure_types": tracker.FailureTypes,
|
||||
"blocked_until": tracker.BlockedUntil,
|
||||
},
|
||||
}
|
||||
sm.processSecurityEvent(blockEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsIPBlocked checks if an IP address is currently blocked
|
||||
func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool {
|
||||
sm.ipMutex.RLock()
|
||||
defer sm.ipMutex.RUnlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
defer tracker.mutex.RUnlock()
|
||||
|
||||
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
|
||||
return true
|
||||
}
|
||||
|
||||
if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) {
|
||||
tracker.IsBlocked = false
|
||||
sm.logger.Infof("IP %s automatically unblocked", clientIP)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// processSecurityEvent processes a security event through all handlers and pattern detection
|
||||
func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
|
||||
if sm.config.EnablePatternDetection {
|
||||
sm.patternDetector.AddEvent(event)
|
||||
|
||||
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
|
||||
if len(patterns) == 1 {
|
||||
sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0])
|
||||
} else {
|
||||
sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns)
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
patternEvent := SecurityEvent{
|
||||
Type: "suspicious_pattern",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
|
||||
Details: map[string]interface{}{
|
||||
"pattern_type": pattern,
|
||||
"trigger_event": event,
|
||||
},
|
||||
}
|
||||
sm.handleSecurityEvent(patternEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sm.handleSecurityEvent(event)
|
||||
}
|
||||
|
||||
// handleSecurityEvent sends the event to all registered handlers
|
||||
func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) {
|
||||
if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") {
|
||||
sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)",
|
||||
event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath)
|
||||
}
|
||||
|
||||
for _, handler := range sm.eventHandlers {
|
||||
go handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// AddEventHandler adds a security event handler
|
||||
func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
|
||||
sm.eventHandlers = append(sm.eventHandlers, handler)
|
||||
}
|
||||
|
||||
// This is kept for API compatibility but doesn't collect actual metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"tracked_ips": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// AddEvent adds an event to the pattern detector
|
||||
func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) {
|
||||
spd.eventsMutex.Lock()
|
||||
defer spd.eventsMutex.Unlock()
|
||||
|
||||
spd.recentEvents = append(spd.recentEvents, event)
|
||||
|
||||
cutoff := time.Now().Add(-spd.longWindow)
|
||||
var filteredEvents []SecurityEvent
|
||||
for _, e := range spd.recentEvents {
|
||||
if e.Timestamp.After(cutoff) {
|
||||
filteredEvents = append(filteredEvents, e)
|
||||
}
|
||||
}
|
||||
spd.recentEvents = filteredEvents
|
||||
}
|
||||
|
||||
// DetectSuspiciousPatterns analyzes recent events for suspicious patterns
|
||||
func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
|
||||
spd.eventsMutex.RLock()
|
||||
defer spd.eventsMutex.RUnlock()
|
||||
|
||||
var patterns []string
|
||||
now := time.Now()
|
||||
|
||||
ipCounts := make(map[string]int)
|
||||
shortWindowStart := now.Add(-spd.shortWindow)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(shortWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
ipCounts[event.ClientIP]++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, count := range ipCounts {
|
||||
if count >= spd.rapidFailureThreshold {
|
||||
patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip))
|
||||
}
|
||||
}
|
||||
|
||||
mediumWindowStart := now.Add(-spd.mediumWindow)
|
||||
uniqueFailingIPs := make(map[string]bool)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(mediumWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
uniqueFailingIPs[event.ClientIP] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(uniqueFailingIPs) >= spd.distributedAttackThreshold {
|
||||
patterns = append(patterns, "distributed_attack_pattern")
|
||||
}
|
||||
|
||||
longWindowStart := now.Add(-spd.longWindow)
|
||||
persistentFailures := 0
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(longWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
persistentFailures++
|
||||
}
|
||||
}
|
||||
|
||||
if persistentFailures >= spd.persistentAttackThreshold {
|
||||
patterns = append(patterns, "persistent_attack_pattern")
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
// startCleanupRoutine starts the background cleanup routine
|
||||
func (sm *SecurityMonitor) startCleanupRoutine() {
|
||||
sm.cleanupTask = NewBackgroundTask(
|
||||
"security-monitor-cleanup",
|
||||
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
|
||||
sm.cleanup,
|
||||
sm.logger)
|
||||
sm.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// StopCleanupRoutine stops the background cleanup routine
|
||||
func (sm *SecurityMonitor) StopCleanupRoutine() {
|
||||
if sm.cleanupTask != nil {
|
||||
sm.cleanupTask.Stop()
|
||||
sm.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old tracking data
|
||||
func (sm *SecurityMonitor) cleanup() {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour)
|
||||
|
||||
for ip, tracker := range sm.ipFailures {
|
||||
tracker.mutex.RLock()
|
||||
shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
if shouldRemove {
|
||||
delete(sm.ipFailures, ip)
|
||||
}
|
||||
}
|
||||
|
||||
sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures))
|
||||
}
|
||||
|
||||
// ExtractClientIP extracts the client IP from the request, considering proxy headers
|
||||
func ExtractClientIP(r *http.Request) string {
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if net.ParseIP(xri) != nil {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
ip := strings.TrimSpace(ips[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// LoggingSecurityEventHandler logs security events to the standard logger
|
||||
type LoggingSecurityEventHandler struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewLoggingSecurityEventHandler creates a new logging event handler
|
||||
func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler {
|
||||
return &LoggingSecurityEventHandler{logger: logger}
|
||||
}
|
||||
|
||||
// HandleSecurityEvent implements SecurityEventHandler
|
||||
func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
switch event.Severity {
|
||||
case "high":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "medium":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "low":
|
||||
h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
default:
|
||||
h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user