mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
security: remediate audit findings (ranks 1–16 + 22 Lows) + yaegi load validation (#144)
* fix(security): encrypt session cookies + fail closed on invalid config
Batch 1 of security audit remediation (ranks 1, 2, 6).
- session.go: derive independent HMAC + AES-256 keys via stdlib HKDF-SHA256
and build the gorilla cookie store with both, so session cookies are now
encrypted, not merely signed. The single-key store previously left OIDC
access/refresh/ID tokens recoverable from raw cookie bytes. Cookie format
changes, so existing sessions are invalidated on deploy (one-time re-login).
- main.go: call config.Validate() at construction and error out on failure,
instead of silently substituting a public hardcoded encryption key for
empty/short keys (which allowed session forgery). The yaegi analyzer
passes via .traefik.yml testData.
- settings.go: isValidSecureURL permits plaintext HTTP for loopback hosts
only (RFC 8252); remote providers must still use HTTPS.
- tests: complete configs that did not satisfy Validate(); add regression
tests in security_audit_fixes_test.go.
Configs below documented minimums (rateLimit < 10, key < 32 chars) are now
rejected at startup (fail closed).
* fix(security): validate discovered OIDC endpoints + pin introspection host
Batch 2 of security audit remediation (ranks 3, 4).
- url_helpers.go: add validateDiscoveredEndpoint, an SSRF screen for endpoints
taken from the provider discovery document (jwks_uri, token, authorization,
revocation, end_session, introspection, registration). Blocks link-local
(cloud metadata 169.254.169.254), multicast, unspecified and private
addresses (unless allowPrivateIPAddresses); blocks loopback unless the
configured providerURL is itself loopback (dev/test). Cross-domain JWKS
hosts (e.g. Google) stay allowed. Add sameHost helper.
- main.go: updateMetadataEndpoints screens every discovered endpoint and
blanks any that fail (fail closed downstream). The introspection endpoint
carries the client secret via HTTP Basic, so it is additionally pinned to
the providerURL host to stop a poisoned discovery document exfiltrating the
secret to an attacker-controlled host.
- tests: regression tests for the SSRF guard and the host pin.
* fix(security): close open redirects + anchor excluded-URL matching
Batch 3 of security audit remediation (ranks 5, 14, 15).
- auth_flow.go: run the stored incoming path through normalizeLogoutPath
before using it as the post-login redirect, so //evil.com and /\evil.com
payloads become host-relative (open-redirect, rank 5).
- url_helpers.go: excluded-URL matching is anchored at a natural boundary
(exact, sub-path "/", or file extension "."), so excluding "/public" no
longer also bypasses auth on "/publicsecret"; "/favicon" still matches
"/favicon.ico" (rank 14).
- internal/utils: X-Forwarded-Host is sanitized (first value only; reject
CRLF/whitespace/multi-value) before building redirect URLs (rank 15).
- helpers.go: the logout redirect used when there is no provider end-session
endpoint is host-relative, never an absolute URL derived from the
client-controllable request host (logout open-redirect, rank 15).
- tests: update two logout cases that asserted the old absolute redirect;
add regression tests.
* fix(security): reject unverified Azure tokens; fix transport TLS reuse
Batch 4 of security audit remediation (ranks 7, 11).
- token_validation_rs.go: an Azure nonce-bearing access token that cannot be
cryptographically verified no longer returns "authenticated" when there is
no ID token to corroborate it; it refreshes (if possible) or forces
re-authentication instead of failing open (rank 7).
- http_client_pool.go: the at-limit transport-reuse path now takes the write
lock before mutating refCount (fixes a data race) and only reuses a
transport whose TLS settings (CA pool + InsecureSkipVerify) match the
caller's, never one with a different trust store; if none matches it returns
nil so the caller falls back to a verifying default transport (rank 11).
- tests: add a transport-pool TLS-isolation regression test.
* fix(security): stop logging templated header values (token leak)
Batch 5 of security audit remediation (rank 16).
middleware.go: templated downstream headers commonly carry the access token
(e.g. "Authorization: Bearer {{.AccessToken}}"). The debug log line printed
the full header value, leaking credentials into logs. Log the header name and
byte length instead.
* fix(security): cache-key collision, cache-config divergence, fleet cleanup
Batch 6 of security audit remediation (ranks 9, 10, 12).
- token_manager.go: detectTokenType keys its cache on a SHA-256 hash of the
full token instead of the first 32 chars (which are only the base64url JWT
header). Distinct tokens sharing alg+kid no longer collide and get
mis-classified (rank 10).
- cache_manager.go: the process-global cache manager is initialized once and
shared across plugin instances; it now logs a loud warning when a later
instance requests a different explicit Redis backend that is silently
ignored, surfacing the cross-instance state-isolation hazard (rank 9).
- singleton_resources.go / main.go / utilities.go: track a process-global live
instance count; the shared singleton-token-cleanup task is stopped only when
the LAST instance shuts down, so one instance's Close() (e.g. a config reload)
no longer kills cleanup for surviving instances (rank 12).
- tests: update TestDetectTokenTypeCaching for the new key; add regression tests.
* fix(security): bound introspection cache + cookie lifetime to config
Batch 7 of security audit remediation (ranks 8, 13).
- token_introspection.go: when requireTokenIntrospection is enabled, cap the
positive introspection-result cache at 30s (instead of 5m) so a token
revoked at the provider stops passing within ~30s, matching the operator's
near-real-time revocation expectation (rank 8).
- session.go: bind the cookie store's MaxAge to the configured sessionMaxAge,
so the cookie codec's cryptographic timestamp validity is no longer fixed at
gorilla's 30-day default; a stolen cookie is valid only for the configured
session lifetime (rank 13).
- tests: add a cookie-lifetime regression test.
* fix(security): low-severity hardening (cache, DoS caps, PKCE, throttle)
Batch 8 of security audit remediation — low severity
(ranks 24, 25, 27, 29, 31, 36, 37, 41, 45, 46, 49).
- universal_cache.go: updateLocalCache updates an existing key in place instead
of orphaning its LRU element and double-counting currentSize/currentMemory
(rank 36 — the only production-reachable bug in this batch).
- jwk.go / metadata_cache.go / token_introspection.go: bound response bodies
with io.LimitReader (1 MiB) to prevent memory exhaustion from a hostile or
buggy provider (ranks 24, 25).
- jwk.go: skip JWKs not usable for signature verification (use != sig, or
key_ops without "verify") when building the key set (rank 49).
- auth_flow.go: fail closed at the callback when PKCE is enabled but the code
verifier is missing, instead of silently dropping it (rank 27).
- utilities.go / main.go: match allowedUserDomains case-insensitively (rank 31).
- bearer_auth.go: a single success no longer wipes an active per-IP penalty;
the counter resets only when no penalty is in effect (rank 29).
- main.go: handle (not discard) the NewSessionManager error (rank 37).
- error_recovery.go: take a write lock in isServiceDegraded (it deletes from a
map); compare retryable-error substrings case-insensitively (ranks 45, 46).
- singleton_resources.go: bind the generic-cache cleanup goroutine to the
resource-manager shutdown channel so it cannot outlive its owner (rank 41).
- tests: update the bearer throttle test to the corrected penalty semantics.
* fix(security): header sanitization, issuer pinning, fail-closed paths
Batch 9 of security audit remediation (ranks 18, 19, 20, 21, 22, 30, 33, 34).
- middleware.go / bearer_auth.go: sanitize claim-derived values on the cookie
auth path before injecting them into downstream headers. Drop group/role and
identifier values containing control chars, bidi-override runes, or the
, ; = delimiters (a comma would inject phantom entries into X-User-Groups);
reject control/bidi/over-length in rendered templated header output (but
permit , ; = in free-form values such as a bearer token). The bearer path
already sanitized; the cookie path did not (ranks 33, 34).
- main.go / metadata_cache.go: pin the discovered issuer to the configured
provider host (sameHost) and refuse/never-cache a mismatch, so a poisoned
discovery document cannot redefine the JWT trust anchor (ranks 21, 22).
- token_introspection.go: when a distinct API audience is configured, fail
closed on a missing or mismatched introspection audience; aud parsed as
string-or-array per RFC 7662 (rank 19).
- logout.go: front-channel logout requires a matching issuer; an empty iss is
rejected (blocks unauthenticated forced-logout via a known sid) (rank 30).
- token_validation_rs.go: an opaque access token with no ID token and no
successful introspection fails closed (re-auth) instead of authenticating
(ranks 18, 20).
- tests: realistic same-host provider mocks; regression tests for the header
sanitization distinction and the fail-closed paths.
* chore(security): remove unwired dead code with latent footguns
Batch 10 of security audit remediation — delete confirmed-dead, unwired
subsystems (ranks 26, 35, 50). None had a production caller (grep-verified);
removal eliminates the latent footguns and ~2.1k lines of dead code.
- token_validator.go (deleted): an unused *TokenValidator whose validateJWT set
Valid=true with NO signature verification — a severe footgun if ever wired
(rank 50). The wired RS-aware validators are unaffected.
- security_monitoring.go (deleted): an unused *SecurityMonitor / ExtractClientIP
that trusted spoofable X-Forwarded-For / X-Real-IP. The live bearer throttle
uses clientIPForBearer (RemoteAddr-only), unchanged (rank 35).
- dynamic_client_registration.go: removed the RFC 7592 management methods
(Update/Read/DeleteClientRegistration) that dereferenced an attacker-
influenced RegistrationClientURI with the registration token attached and no
HTTPS/SSRF gate, and had no callers. The wired RFC 7591 RegisterClient and
credential-store helpers are kept (rank 26).
- tests: removed the tests covering the deleted code.
* chore: add Makefile with yaegi load validation
No Makefile existed. The new `yaegi-validate` target interprets the plugin
under the yaegi interpreter the same way Traefik loads it, catching yaegi-only
incompatibilities (unsupported stdlib symbols, reflection edge cases) that the
native `go build` / `go test` toolchain does not. Importing the plugin forces
yaegi to interpret every file plus its vendored deps; CreateConfig + New
exercise the instantiation path.
- cmd/yaegicheck/main.go: the load driver, marked //go:build ignore so it is
excluded from `go build ./...` (avoids VCS-stamping a main binary, which
fails in git-worktree layouts) yet is run explicitly by yaegi.
- Makefile: build / fmt / vet / lint / test / vendor / yaegi-validate / check
targets; `make check` runs vet + tests + yaegi-validate.
Verified: `make yaegi-validate` passes on this branch — the HKDF cookie
encryption, net-based endpoint validation, and claim sanitizers all interpret
and instantiate cleanly under yaegi.
* ci: bump workflow Go toolchain to 1.25; pin yaegi-validate to v0.16.1
Traefik v3.7.1 (the deployed version) is built with `go 1.25.0`, so the PR and
release workflows now use Go 1.25.x to match the toolchain Traefik uses.
Important distinction: the CI Go version is the build TOOLCHAIN. The plugin's
actual interpreter-compatibility ceiling is the yaegi version Traefik bundles
(v0.16.1, which declares go 1.21 and ships a ~Go 1.22 stdlib symbol surface),
NOT the CI Go version. That ceiling is enforced by `make yaegi-validate` plus
the go.mod language directive — e.g. it is why HKDF is hand-rolled with
hmac+sha256 rather than Go 1.24's crypto/hkdf, which yaegi v0.16.1 lacks.
Also pin Makefile YAEGI_VERSION to v0.16.1 (what Traefik v3.7.1 vendors) so
yaegi-validate exercises the real deployed interpreter instead of @latest,
which could pass on a newer yaegi that supports symbols the deployed one does
not.
* docs: align README/CONFIGURATION with branch behavior changes
- excludedURLs: documented as segment/extension-boundary matching (was
"prefix-matched") — "/public" no longer also matches "/publicsecret" (rank 14).
- Front-channel logout now requires a matching `iss`; requests without one are
rejected with 400 (rank 30).
- Add an "Upgrading from an earlier release" note: session cookies are now
AES-256 encrypted with lifetime tracking sessionMaxAge (one-time re-login on
upgrade), and invalid configuration (rateLimit < 10, key < 32 bytes, missing
callbackURL, non-HTTPS remote providerURL) now fails closed at startup.
* fix: remove staticcheck-flagged unused functions; wire staticcheck into make check
CI Static Analysis (standalone staticcheck) failed with U1000 "unused":
- dynamic_client_registration.go: deleteCredentialsFromStore — its only caller
was the RFC 7592 DeleteClientRegistration removed in the dead-code batch.
- token_test.go: createTestJWTSimple — its only callers were the TokenValidator
tests removed in the same batch.
Both confirmed to have zero remaining callers and removed. build / vet /
go test ./... / staticcheck ./... all green.
The pre-commit hook runs golangci-lint, but CI runs standalone staticcheck
(which flags U1000). Add a `staticcheck` Makefile target and include it in
`make check` so this class of finding is caught locally before push.
* fix(test): stabilize flaky TestWorkerPool_TaskPanic
tasksFailed is incremented in the worker's deferred recover(), which runs after the panicking task's own defer wg.Done(). wg.Wait() could therefore return before the failure was recorded, so reading the counter immediately raced and flaked on slow CI runners. Poll until the failure lands (2s budget) instead. Verified 200x plain + 50x under -race/GOMAXPROCS=1.
This commit is contained in:
@@ -18,6 +18,6 @@ jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
go-version: "1.25.x"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
|
||||
@@ -19,5 +19,5 @@ jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
go-version: "1.25.x"
|
||||
secrets: inherit
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# traefikoidc — Makefile
|
||||
# Run `make help` for available targets.
|
||||
|
||||
GO ?= go
|
||||
GOPATH := $(shell $(GO) env GOPATH)
|
||||
# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate
|
||||
# tests the real interpreter, not a newer one that may support more. Traefik
|
||||
# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is.
|
||||
YAEGI_VERSION ?= v0.16.1
|
||||
TEST_TIMEOUT ?= 480s
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
|
||||
.PHONY: help
|
||||
help: ## Show this help
|
||||
@grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: build
|
||||
build: ## Compile all packages (native toolchain)
|
||||
$(GO) build ./...
|
||||
|
||||
.PHONY: fmt
|
||||
fmt: ## Format sources with gofmt
|
||||
gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/')
|
||||
|
||||
.PHONY: vet
|
||||
vet: ## Run go vet
|
||||
$(GO) vet ./...
|
||||
|
||||
.PHONY: lint
|
||||
lint: ## Run golangci-lint if available
|
||||
@command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping"
|
||||
|
||||
.PHONY: staticcheck
|
||||
staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.)
|
||||
@command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; }
|
||||
@GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./...
|
||||
|
||||
.PHONY: test
|
||||
test: ## Run the test suite
|
||||
$(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT)
|
||||
|
||||
.PHONY: vendor
|
||||
vendor: ## Refresh and vendor dependencies
|
||||
$(GO) mod tidy && $(GO) mod vendor
|
||||
|
||||
# yaegi-validate interprets the plugin under the yaegi interpreter the same way
|
||||
# Traefik loads it. Native `go build`/`go test` use the standard compiler and do
|
||||
# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection
|
||||
# edge cases). This target does. Importing the package forces yaegi to interpret
|
||||
# every file in it plus its vendored deps; CreateConfig + New exercise the
|
||||
# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you
|
||||
# need exact parity.
|
||||
.PHONY: yaegi-validate
|
||||
yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter
|
||||
@command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); }
|
||||
@echo ">> interpreting plugin under yaegi (as Traefik does)"
|
||||
@DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go
|
||||
|
||||
.PHONY: check
|
||||
check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation
|
||||
@@ -112,7 +112,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `postLogoutRedirectURI` | `/` | Where to send users after logout. |
|
||||
| `scopes` | appended to `openid profile email` | Extra OAuth scopes. Set `overrideScopes: true` to replace defaults. |
|
||||
| `extraAuthParams` | none | Map of extra query parameters appended to the authorization request (e.g. `screen_hint: signup`, `login_hint`, `ui_locales`, `prompt`). Plugin-managed params (`client_id`, `state`, `nonce`, `redirect_uri`, `code_challenge`, `scope`, `response_type`, …) cannot be overridden. |
|
||||
| `excludedURLs` | none | Prefix-matched paths that bypass auth. |
|
||||
| `excludedURLs` | none | Paths that bypass auth, matched at a path-segment or file-extension boundary (e.g. `/public` matches `/public`, `/public/sub` and `/public.json`, but **not** `/publicsecret`). |
|
||||
| `allowedUserDomains` | none | Restrict to email domains. |
|
||||
| `allowedUsers` | none | Restrict to specific addresses (or claim values when `userIdentifierClaim != email`). |
|
||||
| `allowedRolesAndGroups` | none | Require any of these roles/groups from ID-token claims. |
|
||||
@@ -147,6 +147,18 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
|
||||
## Production gotchas
|
||||
|
||||
### Upgrading from an earlier release
|
||||
|
||||
- **Sessions are re-issued once.** Session cookies are now AES-256 encrypted
|
||||
(previously signed only) and their cryptographic lifetime tracks
|
||||
`sessionMaxAge` (previously a fixed 30 days). Existing cookies become invalid
|
||||
on upgrade, so users re-authenticate one time.
|
||||
- **Invalid configuration now fails closed at startup** instead of being
|
||||
silently accepted: a `sessionEncryptionKey` shorter than 32 bytes, a
|
||||
`rateLimit` below 10, a missing `callbackURL`, or a non-HTTPS remote
|
||||
`providerURL` are rejected. Plaintext HTTP is permitted only for loopback
|
||||
hosts (local development).
|
||||
|
||||
### TLS termination at a load balancer
|
||||
|
||||
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
|
||||
@@ -166,6 +178,8 @@ detected" when the same token hits different replicas. Two options:
|
||||
|
||||
For IdP-initiated logout (back/front-channel) in multi-replica setups, Redis is
|
||||
**required** so a logout on one instance invalidates sessions on the others.
|
||||
Front-channel logout requests must include a matching `iss` query parameter;
|
||||
requests that omit it are rejected with `400`.
|
||||
|
||||
### Multiple middleware instances on the same host
|
||||
|
||||
|
||||
+9
-1
@@ -182,6 +182,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
if t.enablePKCE && codeVerifier == "" {
|
||||
t.logger.Error("PKCE is enabled but code verifier is missing from session during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
@@ -263,7 +268,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
// Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored
|
||||
// from the original request target before using it as the post-login
|
||||
// redirect target. normalizeLogoutPath forces a host-relative path.
|
||||
redirectPath = normalizeLogoutPath(incomingPath)
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
|
||||
+101
-10
@@ -149,6 +149,94 @@ func parseBearerJOSEHeader(token string) *bearerError {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerClaimRuneReason reports why a rune is unsafe to inject into a request
|
||||
// header value, or "" if the rune is acceptable. Shared core of the bearer-path
|
||||
// identifier sanitizer and the cookie-path header claim sanitizer: rejects
|
||||
// control chars (CRLF/header injection), Unicode bidi-override runes (RTL
|
||||
// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group
|
||||
// name would inject extra entries into a comma-joined header).
|
||||
func headerClaimRuneReason(r rune) string {
|
||||
if reason := headerInjectionRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
// The , ; = delimiters are only unsafe for values placed into delimited or
|
||||
// list contexts (a comma-joined header, or an identifier downstreams may
|
||||
// split). They are valid in arbitrary single header values, so this stricter
|
||||
// check is used for the cookie-path identifier and the group/role list, NOT
|
||||
// for free-form templated header output (see headerValueReason).
|
||||
if r == ',' || r == ';' || r == '=' {
|
||||
return "delimiter character"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header
|
||||
// value, or "" if acceptable. Rejects control characters (CR/LF header
|
||||
// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs).
|
||||
// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in
|
||||
// free-form header values (e.g. an opaque "Authorization: Bearer <token>").
|
||||
func headerInjectionRuneReason(r rune) string {
|
||||
if unicode.IsControl(r) {
|
||||
return "control character"
|
||||
}
|
||||
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
|
||||
return "bidi-override character"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerValueReason reports why value is unsafe to forward as a free-form HTTP
|
||||
// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0
|
||||
// disables the check) and values containing control or bidi-override runes, but
|
||||
// permits , ; = (valid in header values). Empty is allowed. The reason string
|
||||
// never includes the value, so it is safe to log.
|
||||
func headerValueReason(value string, maxLen int) string {
|
||||
if maxLen > 0 && len(value) > maxLen {
|
||||
return "exceeds max length"
|
||||
}
|
||||
for _, r := range value {
|
||||
if reason := headerInjectionRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerClaimValueReason reports why value is unsafe to inject into a
|
||||
// downstream request header, or "" if it is acceptable. It rejects empty
|
||||
// values, values exceeding maxLen (maxLen<=0 disables the length check), and
|
||||
// values containing any rune rejected by headerClaimRuneReason. The reason
|
||||
// string is safe to log (it never includes the value itself).
|
||||
func headerClaimValueReason(value string, maxLen int) string {
|
||||
if value == "" {
|
||||
return "empty value"
|
||||
}
|
||||
if maxLen > 0 && len(value) > maxLen {
|
||||
return "exceeds max length"
|
||||
}
|
||||
for _, r := range value {
|
||||
if reason := headerClaimRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sanitizeHeaderClaimValue validates a claim-derived value before it is
|
||||
// injected into a downstream request header. It trims surrounding whitespace
|
||||
// and fails closed (ok=false) on empty values, values exceeding maxLen
|
||||
// (maxLen<=0 disables the length check), or values containing any rune rejected
|
||||
// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the
|
||||
// bearer path — does not otherwise sanitize the principal identifier or the
|
||||
// group/role strings joined into X-User-Groups / X-User-Roles.
|
||||
func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if headerClaimValueReason(value, maxLen) != "" {
|
||||
return "", false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
// sanitizeBearerIdentifier validates and trims a principal identifier before
|
||||
// it is injected into request headers. Layered defense: net/http will reject
|
||||
// CRLF on the wire too, but rejecting early gives clearer error logs and
|
||||
@@ -163,15 +251,8 @@ func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
|
||||
}
|
||||
for _, r := range identifier {
|
||||
if unicode.IsControl(r) {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains control character")
|
||||
}
|
||||
// Unicode bidi-override range (RTL spoofing of admin UI / SIEM).
|
||||
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains bidi-override character")
|
||||
}
|
||||
if r == ',' || r == ';' || r == '=' {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains delimiter character")
|
||||
if reason := headerClaimRuneReason(r); reason != "" {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason)
|
||||
}
|
||||
}
|
||||
return identifier, nil
|
||||
@@ -342,7 +423,17 @@ func (b *bearerFailureTracker) recordSuccess(ip string) {
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
delete(b.entries, ip)
|
||||
e, ok := b.entries[ip]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Preserve an active penalty so a single success cannot wipe an in-effect
|
||||
// lockout; only reset the counter when no penalty is active or it has expired.
|
||||
now := time.Now()
|
||||
if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) {
|
||||
e.count = 0
|
||||
e.firstFailureAt = now
|
||||
}
|
||||
}
|
||||
|
||||
// clientIPForBearer returns the source IP used to key the failure tracker.
|
||||
|
||||
+21
-3
@@ -303,15 +303,33 @@ func TestBearerFailureTracker(t *testing.T) {
|
||||
if b, retry := tr.blocked(ip); !b || retry <= 0 {
|
||||
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
|
||||
}
|
||||
// Success clears the counter.
|
||||
// A success while a penalty is active must NOT wipe the in-effect lockout
|
||||
// (otherwise a single success could clear an attacker's penalty).
|
||||
tr.recordSuccess(ip)
|
||||
if b, _ := tr.blocked(ip); b {
|
||||
t.Fatalf("expected unblocked after success")
|
||||
if b, _ := tr.blocked(ip); !b {
|
||||
t.Fatalf("expected still blocked after success while penalty active")
|
||||
}
|
||||
// Other IPs are unaffected.
|
||||
if b, _ := tr.blocked("10.0.0.2"); b {
|
||||
t.Fatalf("unrelated IP should not be blocked")
|
||||
}
|
||||
|
||||
// With an expired penalty, a success resets the counter so a subsequent
|
||||
// sub-threshold failure does not immediately re-block.
|
||||
tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond)
|
||||
const ip2 = "10.0.0.3"
|
||||
for i := 0; i < 3; i++ {
|
||||
tr2.recordFailure(ip2)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond) // let the short penalty expire
|
||||
if b, _ := tr2.blocked(ip2); b {
|
||||
t.Fatalf("expected unblocked after penalty expiry")
|
||||
}
|
||||
tr2.recordSuccess(ip2) // resets count since penalty has passed
|
||||
tr2.recordFailure(ip2) // single failure, well below threshold
|
||||
if b, _ := tr2.blocked(ip2); b {
|
||||
t.Fatalf("expected unblocked: counter should have reset after success")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
@@ -18,6 +18,7 @@ type CacheManager struct {
|
||||
var (
|
||||
globalCacheManagerInstance *CacheManager
|
||||
cacheManagerInitOnce sync.Once
|
||||
cacheManagerActiveFingerprint string
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance.
|
||||
@@ -29,7 +30,9 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
fp := redisFingerprint(config)
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
cacheManagerActiveFingerprint = fp
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
@@ -55,9 +58,27 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
// Warn loudly if a later instance asks for a DIFFERENT explicit Redis
|
||||
// backend than the one that won initialization: the cache manager is a
|
||||
// process-global singleton shared across plugin instances (yaegi), so this
|
||||
// instance's divergent configuration is silently ignored, which would
|
||||
// otherwise collapse cache/state isolation between routes (rank 9).
|
||||
if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint {
|
||||
NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp)
|
||||
}
|
||||
return globalCacheManagerInstance
|
||||
}
|
||||
|
||||
// redisFingerprint returns a stable identifier for an explicitly-enabled Redis
|
||||
// backend (address + key prefix), or "" when Redis is not explicitly enabled.
|
||||
// Used to detect divergent cache configurations across plugin instances.
|
||||
func redisFingerprint(config *Config) string {
|
||||
if config == nil || config.Redis == nil || !config.Redis.Enabled {
|
||||
return ""
|
||||
}
|
||||
return config.Redis.Address + "|" + config.Redis.KeyPrefix
|
||||
}
|
||||
|
||||
// GetSharedTokenBlacklist returns the shared token blacklist cache
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build ignore
|
||||
|
||||
// Command yaegicheck verifies that the traefikoidc plugin can be imported and
|
||||
// instantiated by the yaegi interpreter — the same way Traefik loads a plugin.
|
||||
//
|
||||
// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi
|
||||
// to interpret every source file in the package (and its vendored
|
||||
// dependencies), so any construct yaegi cannot handle (unsupported stdlib
|
||||
// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load
|
||||
// time. CreateConfig + New additionally exercise the instantiation path
|
||||
// (session manager, cookie codec, caches, key derivation) under the interpreter.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
oidc "github.com/lukaszraczylo/traefikoidc"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := oidc.CreateConfig()
|
||||
cfg.ProviderURL = "https://accounts.google.com"
|
||||
cfg.ClientID = "yaegi-check-client"
|
||||
cfg.ClientSecret = "yaegi-check-secret"
|
||||
cfg.CallbackURL = "/oauth2/callback"
|
||||
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef"
|
||||
cfg.RateLimit = 100
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
h, err := oidc.New(context.Background(), next, cfg, "yaegi-check")
|
||||
if err != nil {
|
||||
fmt.Println("FAIL: New returned an error under yaegi:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if h == nil {
|
||||
fmt.Println("FAIL: New returned a nil handler under yaegi")
|
||||
os.Exit(1)
|
||||
}
|
||||
if closer, ok := h.(interface{ Close() error }); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi")
|
||||
}
|
||||
@@ -278,82 +278,6 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SECURITY MONITORING TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
config := SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 5,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 30,
|
||||
RapidFailureThreshold: 3,
|
||||
CleanupIntervalMinutes: 60,
|
||||
RetentionHours: 24,
|
||||
EnablePatternDetection: true,
|
||||
EnableDetailedLogging: false,
|
||||
LogSuspiciousOnly: false,
|
||||
}
|
||||
|
||||
sm := NewSecurityMonitor(config, logger)
|
||||
if sm == nil {
|
||||
t.Fatal("Expected non-nil SecurityMonitor")
|
||||
}
|
||||
|
||||
// Start cleanup routine first (lowercase method)
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop cleanup routine (public method)
|
||||
sm.StopCleanupRoutine()
|
||||
|
||||
// Stop again should be safe
|
||||
sm.StopCleanupRoutine()
|
||||
}
|
||||
|
||||
func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
config := SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 5,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 30,
|
||||
RapidFailureThreshold: 3,
|
||||
CleanupIntervalMinutes: 60,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
|
||||
sm := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Create handler
|
||||
handler := &LoggingSecurityEventHandler{logger: logger}
|
||||
|
||||
// Register handler using AddEventHandler
|
||||
sm.AddEventHandler(handler)
|
||||
|
||||
// Record a failure to trigger events
|
||||
sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil)
|
||||
}
|
||||
|
||||
func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
handler := &LoggingSecurityEventHandler{logger: logger}
|
||||
|
||||
// Severity is a string in this implementation
|
||||
events := []SecurityEvent{
|
||||
{Type: "test", Severity: "low", Message: "low severity"},
|
||||
{Type: "test", Severity: "medium", Message: "medium severity"},
|
||||
{Type: "test", Severity: "high", Message: "high severity"},
|
||||
{Type: "test", Severity: "critical", Message: "critical severity"},
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SESSION MANAGER TESTS
|
||||
// =============================================================================
|
||||
|
||||
@@ -178,7 +178,7 @@ clientSecret: your-client-secret
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
|
||||
@@ -370,21 +370,6 @@ func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, res
|
||||
return r.saveCredentials(resp)
|
||||
}
|
||||
|
||||
// deleteCredentialsFromStore removes credentials from the configured storage backend
|
||||
// Falls back to legacy file-based deletion if no store is configured
|
||||
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Delete(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based deletion
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
@@ -423,187 +408,3 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse,
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateClientRegistration updates an existing client registration using RFC 7592
|
||||
// This requires the registration_client_uri and registration_access_token from the original registration
|
||||
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to update")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Build update request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build update request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create update request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read update response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update response: %w", err)
|
||||
}
|
||||
|
||||
// Update cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// ReadClientRegistration reads the current client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to read")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create read request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse read response: %w", err)
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// DeleteClientRegistration deletes the client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return fmt.Errorf("no existing registration to delete")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle error responses (204 No Content is success)
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials from storage if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.deleteCredentialsFromStore(ctx); err != nil {
|
||||
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Successfully deleted client registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -735,258 +735,6 @@ func TestDCRConfigDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateClientRegistration tests the RFC 7592 client update functionality
|
||||
func TestUpdateClientRegistration(t *testing.T) {
|
||||
updateCalled := false
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPut {
|
||||
updateCalled = true
|
||||
|
||||
// Verify authorization header
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
t.Error("Missing Authorization header for update")
|
||||
}
|
||||
|
||||
resp := ClientRegistrationResponse{
|
||||
ClientID: "updated-client-id",
|
||||
ClientSecret: "updated-client-secret",
|
||||
RegistrationAccessToken: "new-access-token",
|
||||
RegistrationClientURI: r.URL.String(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
ClientMetadata: &ClientRegistrationMetadata{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
},
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "original-client-id",
|
||||
ClientSecret: "original-client-secret",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Perform update
|
||||
ctx := context.Background()
|
||||
resp, err := registrar.UpdateClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
if !updateCalled {
|
||||
t.Error("Update endpoint was not called")
|
||||
}
|
||||
|
||||
if resp.ClientID != "updated-client-id" {
|
||||
t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality
|
||||
func TestDeleteClientRegistration(t *testing.T) {
|
||||
deleteCalled := false
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodDelete {
|
||||
deleteCalled = true
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
credentialsFile := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
// Create a credentials file to test deletion
|
||||
os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600)
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
CredentialsFile: credentialsFile,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Perform delete
|
||||
ctx := context.Background()
|
||||
err := registrar.DeleteClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if !deleteCalled {
|
||||
t.Error("Delete endpoint was not called")
|
||||
}
|
||||
|
||||
// Verify cache is cleared
|
||||
if registrar.GetCachedResponse() != nil {
|
||||
t.Error("Cached response should be cleared after deletion")
|
||||
}
|
||||
|
||||
// Verify credentials file is deleted
|
||||
if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) {
|
||||
t.Error("Credentials file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadClientRegistration tests the RFC 7592 client read functionality
|
||||
func TestReadClientRegistration(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet {
|
||||
resp := ClientRegistrationResponse{
|
||||
ClientID: "read-client-id",
|
||||
ClientSecret: "read-client-secret",
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ResponseTypes: []string{"code"},
|
||||
GrantTypes: []string{"authorization_code"},
|
||||
ApplicationType: "web",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "original-client-id",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Read registration
|
||||
ctx := context.Background()
|
||||
resp, err := registrar.ReadClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.ClientID != "read-client-id" {
|
||||
t.Errorf("Read ClientID mismatch: got %s", resp.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationsWithoutCachedResponse tests error handling when no cached response exists
|
||||
func TestOperationsWithoutCachedResponse(t *testing.T) {
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
&http.Client{},
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Update without cached response
|
||||
_, err := registrar.UpdateClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Update should fail without cached response: %v", err)
|
||||
}
|
||||
|
||||
// Test Read without cached response
|
||||
_, err = registrar.ReadClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Read should fail without cached response: %v", err)
|
||||
}
|
||||
|
||||
// Test Delete without cached response
|
||||
err = registrar.DeleteClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Delete should fail without cached response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationsWithoutManagementCredentials tests error handling without management URIs
|
||||
func TestOperationsWithoutManagementCredentials(t *testing.T) {
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
&http.Client{},
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
// Set up cached response WITHOUT management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
// Missing RegistrationAccessToken and RegistrationClientURI
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Update without management credentials
|
||||
_, err := registrar.UpdateClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Update should fail without management credentials: %v", err)
|
||||
}
|
||||
|
||||
// Test Read without management credentials
|
||||
_, err = registrar.ReadClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Read should fail without management credentials: %v", err)
|
||||
}
|
||||
|
||||
// Test Delete without management credentials
|
||||
err = registrar.DeleteClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Delete should fail without management credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stringContains is a helper function to check if a string contains a substring
|
||||
func stringContains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr))
|
||||
|
||||
+6
-5
@@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
errStr := strings.ToLower(err.Error())
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
if contains(errStr, retryableErr) {
|
||||
if contains(errStr, strings.ToLower(retryableErr)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
errStr := netErr.Error()
|
||||
errStr := strings.ToLower(netErr.Error())
|
||||
temporaryPatterns := []string{
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
@@ -859,8 +859,9 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
|
||||
|
||||
// isServiceDegraded checks if a service is currently degraded
|
||||
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
// Uses a write lock because the recovery-timeout branch deletes from the map.
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
degradedTime, exists := gd.degradedServices[serviceName]
|
||||
if !exists {
|
||||
|
||||
+10
-1
@@ -392,10 +392,19 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
// localRedirect is used when there is no provider end-session endpoint and
|
||||
// the plugin redirects the browser itself. It must never be an absolute URL
|
||||
// derived from the request host (X-Forwarded-Host is client-controllable and
|
||||
// would be an open redirect); use a host-relative path, or the operator's
|
||||
// own configured absolute URL, instead.
|
||||
localRedirect := "/"
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
localRedirect = normalizeLogoutPath(postLogoutRedirectURI)
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
} else {
|
||||
localRedirect = postLogoutRedirectURI
|
||||
}
|
||||
|
||||
// Read endSessionURL with RLock
|
||||
@@ -414,7 +423,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
http.Redirect(rw, req, localRedirect, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint.
|
||||
|
||||
+30
-6
@@ -26,6 +26,10 @@ type sharedTransport struct {
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
refCount int
|
||||
// tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify)
|
||||
// this transport was built with, so the at-limit fallback only reuses a
|
||||
// transport whose TLS configuration matches the caller's.
|
||||
tlsKey string
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -53,19 +57,26 @@ func GetGlobalTransportPool() *SharedTransportPool {
|
||||
|
||||
// GetOrCreateTransport gets or creates a shared transport with the given config
|
||||
func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport {
|
||||
// SECURITY FIX: Check client limit before creating new transport
|
||||
// SECURITY FIX: Check client limit before creating new transport.
|
||||
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
|
||||
// Return existing transport if limit reached
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
// At the client limit: only reuse a transport that was built for the
|
||||
// SAME config (same TLS trust store). refCount is mutated under the
|
||||
// write lock to avoid a data race, and a transport created for a
|
||||
// different configuration is never handed back — doing so could apply
|
||||
// the wrong (possibly verification-disabled) TLS settings to a request.
|
||||
want := tlsConfigKey(config)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, shared := range p.transports {
|
||||
if shared != nil && shared.transport != nil {
|
||||
if shared != nil && shared.transport != nil && shared.tlsKey == want {
|
||||
shared.refCount++
|
||||
shared.lastUsed = time.Now()
|
||||
return shared.transport
|
||||
}
|
||||
}
|
||||
// If no transport available, return nil (caller should handle)
|
||||
// No TLS-compatible transport available; return nil so the caller falls
|
||||
// back to a default, certificate-verifying transport rather than one
|
||||
// with a different (possibly verification-disabled) trust store.
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -125,6 +136,7 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
|
||||
transport: transport,
|
||||
refCount: 1,
|
||||
lastUsed: time.Now(),
|
||||
tlsKey: tlsConfigKey(config),
|
||||
}
|
||||
|
||||
return transport
|
||||
@@ -224,6 +236,18 @@ func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
|
||||
)
|
||||
}
|
||||
|
||||
// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool
|
||||
// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are
|
||||
// safe to serve from the same transport even if other (non-TLS) parameters such
|
||||
// as connection limits differ; configs with different TLS settings are not.
|
||||
func tlsConfigKey(config HTTPClientConfig) string {
|
||||
skip := "0"
|
||||
if config.InsecureSkipVerify {
|
||||
skip = "1"
|
||||
}
|
||||
return fmt.Sprintf("%p|%s", config.RootCAs, skip)
|
||||
}
|
||||
|
||||
// Cleanup closes all transports and stops the cleanup goroutine
|
||||
func (p *SharedTransportPool) Cleanup() {
|
||||
p.mu.Lock()
|
||||
|
||||
@@ -842,10 +842,18 @@ func TestWorkerPool_TaskPanic(t *testing.T) {
|
||||
t.Error("Timeout waiting for tasks")
|
||||
}
|
||||
|
||||
// Pool should still be functional
|
||||
metrics := pool.GetMetrics()
|
||||
if metrics["tasksFailed"].(int64) < 1 {
|
||||
// tasksFailed is incremented in the worker's deferred recover(), which runs
|
||||
// AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can
|
||||
// therefore return before the failure is recorded — reading the counter
|
||||
// immediately is a race that flakes on slow/contended CI runners. Poll until
|
||||
// the failure lands (or time out).
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for pool.GetMetrics()["tasksFailed"].(int64) < 1 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Error("Expected at least one failed task")
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+23
-1
@@ -155,12 +155,34 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string {
|
||||
// It checks X-Forwarded-Host header first (for proxy scenarios),
|
||||
// then falls back to req.Host.
|
||||
func DetermineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// sanitizeForwardedHost returns a single, well-formed host from a (possibly
|
||||
// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes
|
||||
// only the first value and rejects whitespace and control characters, so a
|
||||
// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison
|
||||
// the redirect URLs built from the result.
|
||||
func sanitizeForwardedHost(v string) string {
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexByte(v, ','); i >= 0 {
|
||||
v = v[:i]
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 {
|
||||
return ""
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// BuildFullURL constructs a URL from scheme, host, and path components.
|
||||
// It handles absolute URLs (returning them as-is) and ensures paths have leading slashes.
|
||||
func BuildFullURL(scheme, host, path string) string {
|
||||
|
||||
@@ -200,6 +200,22 @@ func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
|
||||
if k.Kid == "" {
|
||||
continue
|
||||
}
|
||||
// Skip keys that are not intended for signature verification.
|
||||
if k.Use != "" && k.Use != "sig" {
|
||||
continue
|
||||
}
|
||||
if len(k.KeyOps) > 0 {
|
||||
hasVerify := false
|
||||
for _, op := range k.KeyOps {
|
||||
if op == "verify" {
|
||||
hasVerify = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasVerify {
|
||||
continue
|
||||
}
|
||||
}
|
||||
var pub crypto.PublicKey
|
||||
var err error
|
||||
switch k.Kty {
|
||||
@@ -242,11 +258,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading JWKS response: %w", err)
|
||||
}
|
||||
|
||||
@@ -134,8 +134,11 @@ func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if iss != "" && iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
|
||||
// Require a matching issuer. An empty iss must be rejected too: accepting a
|
||||
// missing issuer would let an unauthenticated attacker force-logout any
|
||||
// session whose sid is known by simply omitting iss.
|
||||
if iss == "" || iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer)
|
||||
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
+12
-5
@@ -125,10 +125,14 @@ func TestFrontchannelLogoutBasic(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Valid front-channel logout without issuer",
|
||||
// Front-channel logout MUST carry a matching issuer. A request
|
||||
// omitting iss is rejected so an unauthenticated attacker cannot
|
||||
// force-logout a session whose sid is known by simply leaving iss
|
||||
// out (audit rank 30).
|
||||
name: "Missing issuer is rejected",
|
||||
method: http.MethodGet,
|
||||
queryParams: map[string]string{"sid": "session456"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -463,8 +467,9 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Request to front-channel logout path with valid sid should succeed
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil)
|
||||
// Request to front-channel logout path with valid sid + matching issuer
|
||||
// should succeed. The issuer is now required (audit rank 30), so supply it.
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session&iss=https://provider.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
@@ -1432,7 +1437,9 @@ func TestFrontchannelLogoutCacheControl(t *testing.T) {
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil)
|
||||
// Issuer is now required (audit rank 30); supply a matching one so the
|
||||
// successful-logout cache headers can be asserted.
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123&iss=https://provider.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.handleFrontchannelLogout(rw, req)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -112,18 +113,18 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
config = CreateConfig()
|
||||
}
|
||||
|
||||
if config.SessionEncryptionKey == "" {
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
logger := NewLogger(config.LogLevel)
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
||||
} else {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Fail closed on invalid configuration. Validate() enforces the security
|
||||
// constraints (required fields, HTTPS-only URLs, key length, excludedURLs
|
||||
// safety, rate-limit floor, audience format, ...) that were previously
|
||||
// unenforced because this constructor never called it. Crucially it rejects
|
||||
// an empty or too-short SessionEncryptionKey instead of silently
|
||||
// substituting a public hardcoded key, which would let an attacker forge
|
||||
// any session. Traefik's yaegi plugin analyzer supplies a valid key via
|
||||
// .traefik.yml testData, so it passes; only misconfigured deployments fail.
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
// Setup HTTP client
|
||||
caPool, err := config.loadCACertPool()
|
||||
@@ -235,7 +236,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
httpClient: httpClient,
|
||||
tokenHTTPClient: tokenHTTPClient,
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains),
|
||||
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
@@ -346,7 +347,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
// Convert sessionMaxAge from seconds to duration (0 will use default 24 hours)
|
||||
sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults
|
||||
sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("failed to create session manager: %w", err)
|
||||
}
|
||||
t.sessionManager = sessionManager
|
||||
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
|
||||
|
||||
// Initialize token resilience manager with default configuration
|
||||
@@ -437,6 +443,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
// Add reference for this instance
|
||||
rm.AddReference(name)
|
||||
registerLiveInstance()
|
||||
|
||||
// Initialize metadata in a goroutine with proper tracking
|
||||
if t.goroutineWG != nil {
|
||||
@@ -514,13 +521,58 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
// Parameters:
|
||||
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
// SSRF defense (audit ranks 3 & 4): a discovery document is attacker-
|
||||
// influenced when the provider or its TLS is compromised. Reject any
|
||||
// discovered endpoint pointed at a blocked address before the plugin issues
|
||||
// outbound requests to it, so it can never be used to reach the cloud
|
||||
// metadata service or an internal host.
|
||||
allowLoopback := false
|
||||
if pu, err := url.Parse(t.providerURL); err == nil {
|
||||
allowLoopback = isLoopbackHost(pu.Hostname())
|
||||
}
|
||||
sanitize := func(name, raw string) string {
|
||||
if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil {
|
||||
t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err)
|
||||
return ""
|
||||
}
|
||||
return raw
|
||||
}
|
||||
metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL)
|
||||
metadata.AuthURL = sanitize("authorization", metadata.AuthURL)
|
||||
metadata.TokenURL = sanitize("token", metadata.TokenURL)
|
||||
metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL)
|
||||
metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL)
|
||||
metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL)
|
||||
metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL)
|
||||
// The introspection request authenticates with the client secret via HTTP
|
||||
// Basic, so the endpoint must live on the same host as the operator-
|
||||
// configured provider; otherwise a poisoned discovery document could
|
||||
// exfiltrate the client secret to an attacker-controlled host.
|
||||
if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) {
|
||||
t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL)
|
||||
metadata.IntrospectionURL = ""
|
||||
}
|
||||
|
||||
// Pin the discovered issuer to the operator-configured provider host. The
|
||||
// issuer is the trust anchor for JWT issuer validation, so a poisoned
|
||||
// discovery document advertising an attacker-chosen issuer must never be
|
||||
// stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the
|
||||
// issuer on the same host as the configured providerURL. On mismatch, leave
|
||||
// issuerURL empty/unchanged so downstream issuer validation fails closed
|
||||
// rather than trusting the attacker-chosen value.
|
||||
discoveredIssuer := metadata.Issuer
|
||||
if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) {
|
||||
t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer)
|
||||
discoveredIssuer = ""
|
||||
}
|
||||
|
||||
t.metadataMu.Lock()
|
||||
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.issuerURL = discoveredIssuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
|
||||
@@ -533,7 +585,7 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
|
||||
// this directly instead of acquiring metadataMu.RLock per request.
|
||||
t.metadataSnapshot.Store(&MetadataSnapshot{
|
||||
IssuerURL: metadata.Issuer,
|
||||
IssuerURL: discoveredIssuer,
|
||||
JWKSURL: metadata.JWKSURL,
|
||||
TokenURL: metadata.TokenURL,
|
||||
AuthURL: metadata.AuthURL,
|
||||
|
||||
@@ -194,6 +194,7 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(ctx, nil, config, "test")
|
||||
if err != nil {
|
||||
@@ -322,6 +323,7 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(ctx, nil, config, "test")
|
||||
if err != nil {
|
||||
|
||||
+56
-38
@@ -26,38 +26,47 @@ func TestInitializeMetadata(t *testing.T) {
|
||||
name: "successful metadata initialization",
|
||||
providerURL: "",
|
||||
setupMock: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer must share the host with providerURL (the httptest
|
||||
// server), otherwise the discovery doc is rejected as poisoned
|
||||
// (audit ranks 21/22). Real providers keep issuer + endpoints on
|
||||
// the same host, so derive them all from the server URL.
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://provider.example.com",
|
||||
AuthURL: "https://provider.example.com/auth",
|
||||
TokenURL: "https://provider.example.com/token",
|
||||
JWKSURL: "https://provider.example.com/jwks",
|
||||
RevokeURL: "https://provider.example.com/revoke",
|
||||
EndSessionURL: "https://provider.example.com/logout",
|
||||
Issuer: srv.URL,
|
||||
AuthURL: srv.URL + "/auth",
|
||||
TokenURL: srv.URL + "/token",
|
||||
JWKSURL: srv.URL + "/jwks",
|
||||
RevokeURL: srv.URL + "/revoke",
|
||||
EndSessionURL: srv.URL + "/logout",
|
||||
})
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
},
|
||||
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
|
||||
if oidc.authURL != "https://provider.example.com/auth" {
|
||||
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
|
||||
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
|
||||
}
|
||||
if oidc.tokenURL != "https://provider.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
|
||||
}
|
||||
if oidc.jwksURL != "https://provider.example.com/jwks" {
|
||||
if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") {
|
||||
t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL)
|
||||
}
|
||||
if oidc.revocationURL != "https://provider.example.com/revoke" {
|
||||
if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") {
|
||||
t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL)
|
||||
}
|
||||
if oidc.endSessionURL != "https://provider.example.com/logout" {
|
||||
if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") {
|
||||
t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL)
|
||||
}
|
||||
if oidc.issuerURL == "" {
|
||||
t.Errorf("expected issuerURL to be pinned to provider host, got empty")
|
||||
}
|
||||
},
|
||||
wantPanic: false,
|
||||
},
|
||||
@@ -116,24 +125,27 @@ func TestInitializeMetadata(t *testing.T) {
|
||||
name: "partial metadata response",
|
||||
providerURL: "",
|
||||
setupMock: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Only return some fields
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"issuer": "https://partial.example.com",
|
||||
"authorization_endpoint": "https://partial.example.com/auth",
|
||||
"token_endpoint": "https://partial.example.com/token",
|
||||
"issuer": srv.URL,
|
||||
"authorization_endpoint": srv.URL + "/auth",
|
||||
"token_endpoint": srv.URL + "/token",
|
||||
// Missing jwks_uri, revocation_endpoint, end_session_endpoint
|
||||
})
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
},
|
||||
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
|
||||
if oidc.authURL != "https://partial.example.com/auth" {
|
||||
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
|
||||
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
|
||||
}
|
||||
if oidc.tokenURL != "https://partial.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
|
||||
}
|
||||
// JWKS URL and others may be empty
|
||||
@@ -198,20 +210,22 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
|
||||
requestCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requestCount++
|
||||
mu.Unlock()
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://concurrent.example.com",
|
||||
AuthURL: "https://concurrent.example.com/auth",
|
||||
TokenURL: "https://concurrent.example.com/token",
|
||||
JWKSURL: "https://concurrent.example.com/jwks",
|
||||
RevokeURL: "https://concurrent.example.com/revoke",
|
||||
EndSessionURL: "https://concurrent.example.com/logout",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
RevokeURL: server.URL + "/revoke",
|
||||
EndSessionURL: server.URL + "/logout",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -250,7 +264,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
|
||||
// Verify initialization
|
||||
if oidc.tokenURL != "https://concurrent.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set")
|
||||
}
|
||||
}()
|
||||
@@ -342,17 +356,19 @@ func TestProviderDetection(t *testing.T) {
|
||||
// TestInitializationWaiting tests waiting for initialization to complete
|
||||
func TestInitializationWaiting(t *testing.T) {
|
||||
t.Run("wait for initialization completion", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Delay response to simulate slow initialization
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://slow.example.com",
|
||||
AuthURL: "https://slow.example.com/auth",
|
||||
TokenURL: "https://slow.example.com/token",
|
||||
JWKSURL: "https://slow.example.com/jwks",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -389,7 +405,7 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
select {
|
||||
case <-oidc.initComplete:
|
||||
// Success
|
||||
if oidc.tokenURL != "https://slow.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Error("expected tokenURL to be set after initialization")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
@@ -398,17 +414,19 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("multiple waiters for initialization", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Delay to ensure multiple waiters
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://multi.example.com",
|
||||
AuthURL: "https://multi.example.com/auth",
|
||||
TokenURL: "https://multi.example.com/token",
|
||||
JWKSURL: "https://multi.example.com/jwks",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -453,7 +471,7 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
select {
|
||||
case <-oidc.initComplete:
|
||||
// All waiters should see the same initialized state
|
||||
if oidc.tokenURL != "https://multi.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("waiter %d: expected tokenURL to be set", id)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
|
||||
+68
-44
@@ -1875,14 +1875,14 @@ func TestHandleLogout(t *testing.T) {
|
||||
},
|
||||
endSessionURL: "",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
expectedURL: "/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Logout with empty session",
|
||||
setupSession: func(session *SessionData) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
expectedURL: "/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
@@ -2349,19 +2349,22 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
t.Skip("Skipping test in short mode")
|
||||
}
|
||||
|
||||
// Create mock provider metadata server
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create mock provider metadata server. Issuer + endpoints must share the
|
||||
// host with ProviderURL (the httptest server), otherwise the discovery doc
|
||||
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
|
||||
var mockServer *httptest.Server
|
||||
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
RevokeURL: "https://test-issuer.com/revoke",
|
||||
EndSessionURL: "https://test-issuer.com/end-session",
|
||||
Issuer: mockServer.URL,
|
||||
AuthURL: mockServer.URL + "/auth",
|
||||
TokenURL: mockServer.URL + "/token",
|
||||
JWKSURL: mockServer.URL + "/jwks",
|
||||
RevokeURL: mockServer.URL + "/revoke",
|
||||
EndSessionURL: mockServer.URL + "/end-session",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
@@ -2374,6 +2377,7 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create multiple middleware instances
|
||||
@@ -2414,18 +2418,20 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
t.Fatalf("Middleware instance %d failed to initialize", i)
|
||||
}
|
||||
|
||||
// Verify each instance has its own unique configuration
|
||||
if m.issuerURL != "https://test-issuer.com" {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
|
||||
// Verify each instance has its own unique configuration. Issuer is now
|
||||
// pinned to the provider host (audit ranks 21/22), so it equals the
|
||||
// mock server URL rather than a fixed literal.
|
||||
if m.issuerURL != mockServer.URL {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
|
||||
}
|
||||
if m.authURL != "https://test-issuer.com/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
|
||||
if m.authURL != mockServer.URL+"/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
|
||||
}
|
||||
if m.tokenURL != "https://test-issuer.com/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
|
||||
if m.tokenURL != mockServer.URL+"/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
|
||||
}
|
||||
if m.jwksURL != "https://test-issuer.com/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
|
||||
if m.jwksURL != mockServer.URL+"/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
|
||||
}
|
||||
if m.redirURLPath != routes[i]+"/callback" {
|
||||
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
|
||||
@@ -2439,15 +2445,16 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
|
||||
m.ServeHTTP(rr, req)
|
||||
|
||||
// Should redirect to auth URL since not authenticated
|
||||
// Should redirect (302) to the auth flow since not authenticated. The
|
||||
// absolute auth URL is not asserted here: with issuer pinning (audit
|
||||
// ranks 21/22) the discovery host equals the httptest server host,
|
||||
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
|
||||
// to emit a loopback authorization URL in this test environment. The
|
||||
// per-instance auth/token/jwks/issuer URLs were already verified above;
|
||||
// here we only confirm each instance independently triggers a redirect.
|
||||
if rr.Code != http.StatusFound {
|
||||
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
|
||||
}
|
||||
|
||||
location := rr.Header().Get("Location")
|
||||
if !strings.Contains(location, "https://test-issuer.com/auth") {
|
||||
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2460,33 +2467,43 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create two mock provider metadata servers simulating different Keycloak realms
|
||||
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer + endpoints must share the host with each realm's ProviderURL
|
||||
// (the httptest server), otherwise the discovery doc is rejected as
|
||||
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
|
||||
// path so the per-realm isolation assertions below still hold, but base
|
||||
// the host on the server URL — which is exactly what a same-host Keycloak
|
||||
// deployment looks like.
|
||||
var realm1Server *httptest.Server
|
||||
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
base := realm1Server.URL + "/realms/realm1"
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm1",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
|
||||
Issuer: base,
|
||||
AuthURL: base + "/protocol/openid-connect/auth",
|
||||
TokenURL: base + "/protocol/openid-connect/token",
|
||||
JWKSURL: base + "/protocol/openid-connect/certs",
|
||||
EndSessionURL: base + "/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer realm1Server.Close()
|
||||
|
||||
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var realm2Server *httptest.Server
|
||||
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
base := realm2Server.URL + "/realms/realm2"
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm2",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
|
||||
Issuer: base,
|
||||
AuthURL: base + "/protocol/openid-connect/auth",
|
||||
TokenURL: base + "/protocol/openid-connect/token",
|
||||
JWKSURL: base + "/protocol/openid-connect/certs",
|
||||
EndSessionURL: base + "/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
@@ -2500,6 +2517,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
CallbackURL: "/realm1/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm1_",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Config for realm2
|
||||
@@ -2510,6 +2528,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
CallbackURL: "/realm2/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm2_",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware instances for both realms
|
||||
@@ -2608,8 +2627,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable := false
|
||||
var mu sync.Mutex
|
||||
|
||||
// Create mock provider that initially fails, then becomes available
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create mock provider that initially fails, then becomes available.
|
||||
// Issuer + endpoints must share the host with ProviderURL (audit ranks
|
||||
// 21/22), so derive them from the server URL.
|
||||
var mockServer *httptest.Server
|
||||
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
available := providerAvailable
|
||||
mu.Unlock()
|
||||
@@ -2621,11 +2643,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
EndSessionURL: "https://test-issuer.com/logout",
|
||||
Issuer: mockServer.URL,
|
||||
AuthURL: mockServer.URL + "/auth",
|
||||
TokenURL: mockServer.URL + "/token",
|
||||
JWKSURL: mockServer.URL + "/jwks",
|
||||
EndSessionURL: mockServer.URL + "/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
return
|
||||
@@ -2640,6 +2662,7 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware while provider is unavailable
|
||||
@@ -4552,6 +4575,7 @@ func TestNewWithScopeAppending(t *testing.T) {
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
Scopes: tc.configScopes,
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware instance
|
||||
|
||||
@@ -1652,6 +1652,7 @@ func TestGoroutineLeaks(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client"
|
||||
config.ClientSecret = "test-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(context.Background(), nil, config, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
+11
-1
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -141,10 +142,19 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
|
||||
}
|
||||
|
||||
var metadata ProviderMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode metadata: %w", err)
|
||||
}
|
||||
|
||||
// Pin the advertised issuer to the configured provider host. The issuer is
|
||||
// the trust anchor for JWT issuer validation; rejecting a mismatch here
|
||||
// ensures a poisoned discovery document advertising an attacker-chosen
|
||||
// issuer is never cached or returned. Real providers (Google, Azure,
|
||||
// Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL.
|
||||
if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) {
|
||||
return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL)
|
||||
}
|
||||
|
||||
// Cache for 1 hour by default
|
||||
if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil {
|
||||
mc.logger.Errorf("Failed to cache metadata: %v", err)
|
||||
|
||||
+80
-7
@@ -472,6 +472,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - req: The HTTP request to process.
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
//
|
||||
// processAuthorizedRequestRS is the requestState-aware variant of
|
||||
// processAuthorizedRequest. It reads SessionData fields from the captured
|
||||
// snapshot in rs instead of calling session.GetX() (each of which acquires
|
||||
@@ -675,6 +676,44 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
//
|
||||
// Session persistence is the CALLER's responsibility — it must happen before
|
||||
// this function so Set-Cookie reaches the response.
|
||||
// headerTemplateMaxLen bounds the length of a rendered operator-defined header
|
||||
// template before it is forwarded downstream. Generous enough for an
|
||||
// "Authorization: Bearer <jwt>" value but small enough to reject obviously
|
||||
// abusive output. Matches the input-validation default header cap (8KB).
|
||||
const headerTemplateMaxLen = 8192
|
||||
|
||||
// headerClaimMaxLen returns the maximum accepted length for a claim-derived
|
||||
// header value (principal identifier, group, role). Reuses the operator-
|
||||
// configured identifier cap (default 256) so a single setting governs both
|
||||
// auth paths; falls back to 256 when unset.
|
||||
func (t *TraefikOidc) headerClaimMaxLen() int {
|
||||
if t.maxIdentifierLength > 0 {
|
||||
return t.maxIdentifierLength
|
||||
}
|
||||
return 256
|
||||
}
|
||||
|
||||
// sanitizeHeaderClaimList drops any group/role value that fails claim
|
||||
// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an
|
||||
// over-long value) and returns the surviving values. Failing closed on a bad
|
||||
// entry prevents header injection and stops an embedded comma from injecting
|
||||
// extra entries into the comma-joined header. headerName is used only for
|
||||
// debug logging — the value is never logged.
|
||||
func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
safe := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok {
|
||||
safe = append(safe, clean)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName)
|
||||
}
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) {
|
||||
var (
|
||||
groups, roles []string
|
||||
@@ -692,11 +731,18 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
|
||||
return
|
||||
}
|
||||
if extractErr == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
// Sanitize each group/role before it is joined into a comma-
|
||||
// delimited header. The cookie/session path does not otherwise
|
||||
// sanitize claim-derived values (the bearer path sanitizes its
|
||||
// identifier at construction), so a control char would enable
|
||||
// header injection and an embedded comma would inject extra
|
||||
// entries into the comma-joined header. Fail closed: drop any
|
||||
// value that does not pass.
|
||||
if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(safeGroups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(safeRoles, ","))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -717,12 +763,26 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", p.Identifier)
|
||||
// Sanitize the principal identifier before injecting it into headers. The
|
||||
// bearer path already sanitizes its identifier at construction; the
|
||||
// cookie/session path does not, so a claim carrying control chars, bidi-
|
||||
// override runes, or , ; = could inject or spoof header content. Fail
|
||||
// closed: drop the identifier header(s) rather than forward a tainted value.
|
||||
safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen())
|
||||
if identifierOK {
|
||||
req.Header.Set("X-Forwarded-User", safeIdentifier)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization")
|
||||
}
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", p.Identifier)
|
||||
if identifierOK {
|
||||
req.Header.Set("X-Auth-Request-User", safeIdentifier)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization")
|
||||
}
|
||||
if p.IDToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", p.IDToken)
|
||||
}
|
||||
@@ -747,8 +807,21 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
// Sanitize the rendered output: template inputs are claim-derived
|
||||
// and attacker-influenceable, so reject control chars (header
|
||||
// injection), bidi-override runes, the , ; = delimiters, and an
|
||||
// over-long value. Fail closed by dropping the header rather than
|
||||
// forwarding a tainted value. Do not log the value (it commonly
|
||||
// carries the access token); log only name + reason.
|
||||
if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" {
|
||||
t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason)
|
||||
continue
|
||||
}
|
||||
req.Header.Set(headerName, headerValue)
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
// Do not log the value: templated headers commonly carry the access
|
||||
// token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging
|
||||
// it — even at debug — leaks credentials into logs.
|
||||
t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// TestRank1_SessionCookieIsEncrypted verifies that the session cookie payload is
|
||||
// AES-encrypted, not merely HMAC-signed. Regression test for the audit finding
|
||||
// "session cookies signed but NOT encrypted": a single key left the stored OIDC
|
||||
// tokens recoverable in plaintext from the raw cookie bytes.
|
||||
func TestRank1_SessionCookieIsEncrypted(t *testing.T) {
|
||||
const secret = "a-sufficiently-long-session-encryption-key"
|
||||
authKey, encKey := deriveCookieKeys(secret)
|
||||
if len(authKey) != 64 || len(encKey) != 32 {
|
||||
t.Fatalf("expected 64-byte auth key and 32-byte enc key, got %d/%d", len(authKey), len(encKey))
|
||||
}
|
||||
if string(authKey) == string(encKey) {
|
||||
t.Fatal("authentication and encryption keys must be independent")
|
||||
}
|
||||
|
||||
const marker = "SUPER-SECRET-ACCESS-TOKEN-marker-value"
|
||||
|
||||
// Encode a session through the same two-key store the production code now
|
||||
// builds (see NewSessionManager).
|
||||
store := sessions.NewCookieStore(authKey, encKey)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
sess, err := store.New(req, "session")
|
||||
if err != nil {
|
||||
t.Fatalf("store.New failed: %v", err)
|
||||
}
|
||||
sess.Values["tok"] = marker
|
||||
if err := sess.Save(req, rec); err != nil {
|
||||
t.Fatalf("session save failed: %v", err)
|
||||
}
|
||||
|
||||
var cookie *http.Cookie
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == "session" {
|
||||
cookie = c
|
||||
}
|
||||
}
|
||||
if cookie == nil {
|
||||
t.Fatal("no session cookie was set")
|
||||
}
|
||||
|
||||
// The secret token must never appear in plaintext in the cookie value.
|
||||
if strings.Contains(cookie.Value, marker) {
|
||||
t.Error("marker token found in plaintext inside the session cookie value")
|
||||
}
|
||||
|
||||
// A store holding only the authentication key (the previous behavior)
|
||||
// must NOT be able to read the encrypted cookie — proving the payload is
|
||||
// genuinely encrypted, not just signed.
|
||||
signedOnly := sessions.NewCookieStore(authKey)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.AddCookie(cookie)
|
||||
if _, derr := signedOnly.Get(req2, "session"); derr == nil {
|
||||
t.Error("encrypted cookie should not be decodable without the encryption key")
|
||||
}
|
||||
|
||||
// The full two-key store round-trips correctly.
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req3.AddCookie(cookie)
|
||||
rt, derr := store.Get(req3, "session")
|
||||
if derr != nil {
|
||||
t.Fatalf("round-trip decode failed: %v", derr)
|
||||
}
|
||||
if got, _ := rt.Values["tok"].(string); got != marker {
|
||||
t.Errorf("round-trip mismatch: got %q want %q", got, marker)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank2And6_InvalidConfigFailsClosed verifies that NewWithContext now calls
|
||||
// Config.Validate() and fails closed on an empty or too-short session
|
||||
// encryption key instead of silently substituting a public hardcoded key, and
|
||||
// rejects other missing required fields. Regression test for "hardcoded default
|
||||
// encryption key" + "Config.Validate() never called in production path".
|
||||
func TestRank2And6_InvalidConfigFailsClosed(t *testing.T) {
|
||||
base := func() *Config {
|
||||
return &Config{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "this-is-a-valid-session-key-32b!",
|
||||
RateLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity: a fully valid config still constructs.
|
||||
p, err := NewWithContext(context.Background(), base(), nil, "valid")
|
||||
if err != nil {
|
||||
t.Fatalf("valid config should construct, got: %v", err)
|
||||
}
|
||||
if p != nil {
|
||||
p.Close()
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
}{
|
||||
{"empty key", func(c *Config) { c.SessionEncryptionKey = "" }},
|
||||
{"short key", func(c *Config) { c.SessionEncryptionKey = "tooshort" }},
|
||||
{"missing providerURL", func(c *Config) { c.ProviderURL = "" }},
|
||||
{"missing callbackURL", func(c *Config) { c.CallbackURL = "" }},
|
||||
{"plaintext remote providerURL", func(c *Config) { c.ProviderURL = "http://accounts.google.com" }},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := base()
|
||||
tc.mutate(c)
|
||||
plugin, err := NewWithContext(context.Background(), c, nil, tc.name)
|
||||
if err == nil {
|
||||
if plugin != nil {
|
||||
plugin.Close()
|
||||
}
|
||||
t.Errorf("expected NewWithContext to reject config (%s), but it succeeded", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank3_DiscoveredEndpointSSRFGuard verifies that endpoints from the
|
||||
// provider discovery document are screened against SSRF targets before use.
|
||||
func TestRank3_DiscoveredEndpointSSRFGuard(t *testing.T) {
|
||||
tr := &TraefikOidc{}
|
||||
|
||||
blocked := []string{
|
||||
"http://169.254.169.254/latest/meta-data/", // cloud metadata (link-local)
|
||||
"http://[fe80::1]/jwks", // IPv6 link-local
|
||||
"http://10.0.0.5/jwks", // private
|
||||
"http://192.168.1.10/jwks", // private
|
||||
"http://127.0.0.1/jwks", // loopback (allowLoopback=false)
|
||||
"ftp://example.com/jwks", // disallowed scheme
|
||||
}
|
||||
for _, u := range blocked {
|
||||
if err := tr.validateDiscoveredEndpoint(u, false); err == nil {
|
||||
t.Errorf("expected discovered endpoint %q to be rejected", u)
|
||||
}
|
||||
}
|
||||
|
||||
allowed := []string{
|
||||
"https://accounts.google.com/o/oauth2/v3/certs",
|
||||
"https://www.googleapis.com/oauth2/v3/certs", // cross-domain JWKS must stay allowed
|
||||
"", // empty optional endpoint
|
||||
}
|
||||
for _, u := range allowed {
|
||||
if err := tr.validateDiscoveredEndpoint(u, false); err != nil {
|
||||
t.Errorf("expected discovered endpoint %q to be allowed, got %v", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Loopback is allowed only when the provider itself is loopback (dev/test).
|
||||
if err := tr.validateDiscoveredEndpoint("http://127.0.0.1:8080/jwks", true); err != nil {
|
||||
t.Errorf("loopback endpoint should be allowed when allowLoopback=true: %v", err)
|
||||
}
|
||||
// Private addresses are allowed when explicitly opted in.
|
||||
trPriv := &TraefikOidc{allowPrivateIPAddresses: true}
|
||||
if err := trPriv.validateDiscoveredEndpoint("http://10.0.0.5/jwks", false); err != nil {
|
||||
t.Errorf("private endpoint should be allowed when allowPrivateIPAddresses=true: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank4_IntrospectionHostPin verifies the host-equality check used to pin
|
||||
// the credential-bearing introspection endpoint to the configured provider.
|
||||
func TestRank4_IntrospectionHostPin(t *testing.T) {
|
||||
if !sameHost("https://kc.example.com/realms/x", "https://kc.example.com/realms/x/protocol/openid-connect/token/introspect") {
|
||||
t.Error("introspection on the same host as the provider should be accepted")
|
||||
}
|
||||
if sameHost("https://kc.example.com", "https://evil.example.net/introspect") {
|
||||
t.Error("introspection on a different host must be rejected")
|
||||
}
|
||||
if sameHost("", "https://kc.example.com") || sameHost("https://kc.example.com", "") {
|
||||
t.Error("empty URL must not be treated as a host match")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank5_OpenRedirectNeutralized verifies the helper the callback now applies
|
||||
// to the stored incoming path forces a host-relative redirect target.
|
||||
func TestRank5_OpenRedirectNeutralized(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"//evil.com/x": "/evil.com/x",
|
||||
`/\evil.com`: "/evil.com",
|
||||
"/legit/path": "/legit/path",
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := normalizeLogoutPath(in)
|
||||
if got != want {
|
||||
t.Errorf("normalizeLogoutPath(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
if strings.HasPrefix(got, "//") || strings.HasPrefix(got, `/\`) {
|
||||
t.Errorf("normalizeLogoutPath(%q) = %q is still protocol-relative", in, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank14_ExcludedURLSegmentBoundary verifies excluded-URL matching is
|
||||
// anchored at path-segment boundaries and cannot be widened into a bypass.
|
||||
func TestRank14_ExcludedURLSegmentBoundary(t *testing.T) {
|
||||
if !pathExcluded("/public", "/public") {
|
||||
t.Error("exact match should be excluded")
|
||||
}
|
||||
if !pathExcluded("/public/page", "/public") {
|
||||
t.Error("sub-path should be excluded")
|
||||
}
|
||||
if pathExcluded("/publicsecret", "/public") {
|
||||
t.Error("/publicsecret must NOT be excluded by /public")
|
||||
}
|
||||
if pathExcluded("/public-admin", "/public") {
|
||||
t.Error("/public-admin must NOT be excluded by /public")
|
||||
}
|
||||
if !pathExcluded("/health", "/health/") {
|
||||
t.Error("trailing-slash config should still match the exact path")
|
||||
}
|
||||
if pathExcluded("/anything", "/") {
|
||||
t.Error("root exclusion must not match arbitrary paths")
|
||||
}
|
||||
if !pathExcluded("/", "/") {
|
||||
t.Error("root exclusion should match the root path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank15_ForwardedHostSanitized verifies a crafted X-Forwarded-Host cannot
|
||||
// inject CRLF, smuggle a second host, or otherwise poison the derived host.
|
||||
func TestRank15_ForwardedHostSanitized(t *testing.T) {
|
||||
mk := func(xfh string) *http.Request {
|
||||
r := httptest.NewRequest(http.MethodGet, "http://real.example.com/x", nil)
|
||||
r.Host = "real.example.com"
|
||||
if xfh != "" {
|
||||
r.Header.Set("X-Forwarded-Host", xfh)
|
||||
}
|
||||
return r
|
||||
}
|
||||
if got := utils.DetermineHost(mk("ext.example.com")); got != "ext.example.com" {
|
||||
t.Errorf("clean X-Forwarded-Host should be honored, got %q", got)
|
||||
}
|
||||
if got := utils.DetermineHost(mk("a.example.com, evil.com")); got != "a.example.com" {
|
||||
t.Errorf("multi-value X-Forwarded-Host should use first host only, got %q", got)
|
||||
}
|
||||
for _, bad := range []string{"evil.com\r\nSet-Cookie: x=1", "evil.com /x", " "} {
|
||||
if got := utils.DetermineHost(mk(bad)); got != "real.example.com" {
|
||||
t.Errorf("malformed X-Forwarded-Host %q should fall back to req.Host, got %q", bad, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank11_TransportPoolTLSIsolationAtLimit verifies that, once the client
|
||||
// limit is reached, the transport pool reuses an existing transport only when
|
||||
// its TLS settings match the caller's, and never hands back a transport built
|
||||
// with different TLS trust settings.
|
||||
func TestRank11_TransportPoolTLSIsolationAtLimit(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
strict := DefaultHTTPClientConfig() // InsecureSkipVerify = false
|
||||
t1 := pool.GetOrCreateTransport(strict)
|
||||
if t1 == nil {
|
||||
t.Fatal("expected a transport for the strict config")
|
||||
}
|
||||
|
||||
// Saturate the client limit so subsequent calls hit the fallback path.
|
||||
atomic.StoreInt32(&pool.clientCount, pool.maxClients)
|
||||
|
||||
// Same TLS settings, different (non-TLS) connection limit: safe to reuse.
|
||||
sameTLS := DefaultHTTPClientConfig()
|
||||
sameTLS.MaxConnsPerHost = 99
|
||||
if got := pool.GetOrCreateTransport(sameTLS); got != t1 {
|
||||
t.Error("at the limit a TLS-compatible config should reuse the existing transport")
|
||||
}
|
||||
|
||||
// Different TLS settings (InsecureSkipVerify): must NOT reuse the strict
|
||||
// transport — returning nil lets the caller fall back to a verifying default.
|
||||
insecure := DefaultHTTPClientConfig()
|
||||
insecure.InsecureSkipVerify = true
|
||||
if got := pool.GetOrCreateTransport(insecure); got == t1 {
|
||||
t.Error("at the limit a config with different TLS settings must not reuse the strict transport")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank9_RedisFingerprint verifies divergent explicit Redis backends produce
|
||||
// distinct fingerprints (used to warn about ignored cache config), while an
|
||||
// absent or disabled Redis yields the empty (no-warning) fingerprint.
|
||||
func TestRank9_RedisFingerprint(t *testing.T) {
|
||||
if redisFingerprint(nil) != "" {
|
||||
t.Error("nil config should yield an empty fingerprint")
|
||||
}
|
||||
if redisFingerprint(&Config{}) != "" {
|
||||
t.Error("config without Redis should yield an empty fingerprint")
|
||||
}
|
||||
if redisFingerprint(&Config{Redis: &RedisConfig{Enabled: false, Address: "a:6379"}}) != "" {
|
||||
t.Error("disabled Redis should yield an empty fingerprint")
|
||||
}
|
||||
a := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "a:6379", KeyPrefix: "p"}})
|
||||
b := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "b:6379", KeyPrefix: "p"}})
|
||||
if a == "" || a == b {
|
||||
t.Errorf("distinct enabled backends must produce distinct non-empty fingerprints (%q vs %q)", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank10_TokenTypeCacheKeyNoCollision verifies that two different tokens
|
||||
// sharing the same 32-character JWT header prefix are classified independently.
|
||||
// The previous 32-char cache key would have collided and mis-classified them.
|
||||
func TestRank10_TokenTypeCacheKeyNoCollision(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
tokenTypeCache: NewCache(),
|
||||
suppressDiagnosticLogs: true,
|
||||
clientID: "client",
|
||||
}
|
||||
// A header prefix longer than 32 chars, shared by both tokens.
|
||||
prefix := "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ"
|
||||
idJWT := &JWT{Header: map[string]interface{}{}, Claims: map[string]interface{}{"nonce": "n"}}
|
||||
accessJWT := &JWT{Header: map[string]interface{}{"typ": "at+jwt"}, Claims: map[string]interface{}{}}
|
||||
|
||||
if !tr.detectTokenType(idJWT, prefix+".id.sig") {
|
||||
t.Error("token with a nonce claim should be detected as an ID token")
|
||||
}
|
||||
if tr.detectTokenType(accessJWT, prefix+".access.sig") {
|
||||
t.Error("access token (typ=at+jwt) must not be mis-classified as ID despite the shared 32-char prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank12_LiveInstanceCounter verifies the process-global instance counter
|
||||
// that gates teardown of shared singleton tasks.
|
||||
func TestRank12_LiveInstanceCounter(t *testing.T) {
|
||||
start := atomic.LoadInt32(&liveInstanceCount)
|
||||
registerLiveInstance()
|
||||
registerLiveInstance()
|
||||
if got := atomic.LoadInt32(&liveInstanceCount); got != start+2 {
|
||||
t.Fatalf("expected %d live instances, got %d", start+2, got)
|
||||
}
|
||||
if rem := unregisterLiveInstance(); rem != start+1 {
|
||||
t.Errorf("expected %d remaining, got %d", start+1, rem)
|
||||
}
|
||||
if rem := unregisterLiveInstance(); rem != start {
|
||||
t.Errorf("expected %d remaining, got %d", start, rem)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank13_CookieMaxAgeMatchesSessionLifetime verifies the cookie store's
|
||||
// MaxAge (which bounds both the cookie Max-Age and the codec's cryptographic
|
||||
// timestamp validity) is bound to the configured session lifetime rather than
|
||||
// gorilla's 30-day default.
|
||||
func TestRank13_CookieMaxAgeMatchesSessionLifetime(t *testing.T) {
|
||||
maxAge := 2 * time.Hour
|
||||
sm, err := NewSessionManager(strings.Repeat("k", 40), false, "", "", maxAge, NewLogger("error"))
|
||||
if err != nil {
|
||||
t.Fatalf("NewSessionManager failed: %v", err)
|
||||
}
|
||||
defer sm.cancel()
|
||||
|
||||
cs, ok := sm.store.(*sessions.CookieStore)
|
||||
if !ok {
|
||||
t.Fatal("session store is not a *sessions.CookieStore")
|
||||
}
|
||||
if got := cs.Options.MaxAge; got != int(maxAge.Seconds()) {
|
||||
t.Errorf("cookie store MaxAge = %d, want %d (bound to sessionMaxAge)", got, int(maxAge.Seconds()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank33And34_HeaderSanitizationDistinction verifies the two header sinks
|
||||
// use the right strictness: free-form templated header VALUES (rank 34) permit
|
||||
// , ; = (e.g. an opaque "Bearer <token>" or an LDAP-DN claim) but reject CR/LF,
|
||||
// bidi, and over-length; claim values joined into delimited/identifier headers
|
||||
// (rank 33) additionally reject , ; =.
|
||||
func TestRank33And34_HeaderSanitizationDistinction(t *testing.T) {
|
||||
// Rank 34 — free-form header value.
|
||||
if headerValueReason("Bearer abc=def==", 8192) != "" {
|
||||
t.Error("'=' must be allowed in a free-form header value (opaque bearer token)")
|
||||
}
|
||||
if headerValueReason("cn=user,ou=eng;dc=x", 8192) != "" {
|
||||
t.Error("',;=' must be allowed in a free-form header value (e.g. an LDAP DN claim)")
|
||||
}
|
||||
if headerValueReason("evil"+string(rune(13))+string(rune(10))+"Injected: 1", 8192) == "" {
|
||||
t.Error("CR/LF must be rejected in a header value (injection)")
|
||||
}
|
||||
if headerValueReason("toolong", 3) == "" {
|
||||
t.Error("over-length value must be rejected")
|
||||
}
|
||||
|
||||
// Rank 33 — claim value bound for a delimited/identifier header.
|
||||
if _, ok := sanitizeHeaderClaimValue("admins,superadmins", 256); ok {
|
||||
t.Error("a comma must be rejected in a value joined into a comma-delimited header")
|
||||
}
|
||||
if _, ok := sanitizeHeaderClaimValue("normal-user@example.com", 256); !ok {
|
||||
t.Error("a clean identifier must pass claim sanitization")
|
||||
}
|
||||
if _, ok := sanitizeHeaderClaimValue("evil"+string(rune(13))+string(rune(10))+"X: 1", 256); ok {
|
||||
t.Error("CR/LF must be rejected in a claim value")
|
||||
}
|
||||
}
|
||||
@@ -1,590 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityEventType categorizes different types of security events
|
||||
// that can occur during OIDC authentication and authorization flows.
|
||||
type SecurityEventType string
|
||||
|
||||
// Security event types for monitoring and alerting
|
||||
const (
|
||||
// AuthFailure indicates a failed authentication attempt
|
||||
AuthFailure SecurityEventType = "authentication_failure"
|
||||
// TokenValidFailure indicates JWT token validation failed
|
||||
TokenValidFailure SecurityEventType = "token_validation_failure"
|
||||
// RateLimitHit indicates rate limiting was triggered
|
||||
RateLimitHit SecurityEventType = "rate_limit_hit"
|
||||
// SuspiciousActivity indicates potentially malicious behavior
|
||||
SuspiciousActivity SecurityEventType = "suspicious_activity"
|
||||
)
|
||||
|
||||
// DefaultSeverity returns the default severity level for each security event type.
|
||||
// Severity levels are: low, medium, high.
|
||||
func (t SecurityEventType) DefaultSeverity() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "medium"
|
||||
case TokenValidFailure:
|
||||
return "medium"
|
||||
case RateLimitHit:
|
||||
return "low"
|
||||
case SuspiciousActivity:
|
||||
return "high"
|
||||
default:
|
||||
return "medium"
|
||||
}
|
||||
}
|
||||
|
||||
// IPFailureType returns a string identifier for categorizing failures
|
||||
// by IP address for rate limiting and blocking decisions.
|
||||
func (t SecurityEventType) IPFailureType() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "auth_failure"
|
||||
case TokenValidFailure:
|
||||
return "token_failure"
|
||||
case SuspiciousActivity:
|
||||
return "suspicious"
|
||||
default:
|
||||
return "general"
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityEvent represents a security-related event with comprehensive context.
|
||||
// Contains timing information, IP address, user agent, request details,
|
||||
// and custom event-specific data for security analysis and alerting.
|
||||
type SecurityEvent struct {
|
||||
// Timestamp when the event occurred
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
// Details contains event-specific additional information
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
// Type categorizes the event (auth_failure, token_failure, etc.)
|
||||
Type string `json:"type"`
|
||||
// Severity indicates event importance (low, medium, high)
|
||||
Severity string `json:"severity"`
|
||||
// ClientIP is the source IP address of the request
|
||||
ClientIP string `json:"client_ip"`
|
||||
// UserAgent is the User-Agent header from the request
|
||||
UserAgent string `json:"user_agent"`
|
||||
// RequestPath is the requested URL path
|
||||
RequestPath string `json:"request_path"`
|
||||
// Message provides human-readable description of the event
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SecurityMonitor provides comprehensive security monitoring for the OIDC middleware.
|
||||
// It tracks failures by IP address, detects suspicious patterns, enforces
|
||||
// rate limits, and can trigger custom security event handlers.
|
||||
type SecurityMonitor struct {
|
||||
ipFailures map[string]*IPFailureTracker
|
||||
patternDetector *SuspiciousPatternDetector
|
||||
logger *Logger
|
||||
cleanupTask *BackgroundTask
|
||||
eventHandlers []SecurityEventHandler
|
||||
config SecurityMonitorConfig
|
||||
ipMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// IPFailureTracker maintains failure statistics and blocking state for an IP address.
|
||||
// Used for implementing progressive penalties and automatic IP blocking based on
|
||||
// failure patterns, with support for different failure types for
|
||||
// rate limiting and IP blocking decisions.
|
||||
type IPFailureTracker struct {
|
||||
// LastFailure timestamp of the most recent failure
|
||||
LastFailure time.Time
|
||||
// FirstFailure timestamp of the first failure in current window
|
||||
FirstFailure time.Time
|
||||
// BlockedUntil indicates when the IP block expires
|
||||
BlockedUntil time.Time
|
||||
// FailureTypes tracks counts by failure type
|
||||
FailureTypes map[string]int64
|
||||
// FailureCount total number of failures
|
||||
FailureCount int64
|
||||
// mutex protects concurrent access to tracker data
|
||||
mutex sync.RWMutex
|
||||
// IsBlocked indicates if this IP is currently blocked
|
||||
IsBlocked bool
|
||||
}
|
||||
|
||||
// SuspiciousPatternDetector identifies attack patterns that may indicate coordinated threats.
|
||||
// Analyzes events across multiple time windows to detect rapid failures, distributed attacks,
|
||||
// and persistent attack patterns that individual IP monitoring might miss.
|
||||
type SuspiciousPatternDetector struct {
|
||||
// recentEvents stores recent security events for analysis
|
||||
recentEvents []SecurityEvent
|
||||
// shortWindow defines time frame for rapid failure detection
|
||||
shortWindow time.Duration
|
||||
// mediumWindow defines time frame for distributed attack detection
|
||||
mediumWindow time.Duration
|
||||
// longWindow defines time frame for persistent attack detection
|
||||
longWindow time.Duration
|
||||
// rapidFailureThreshold triggers rapid failure alerts
|
||||
rapidFailureThreshold int
|
||||
// distributedAttackThreshold triggers distributed attack alerts
|
||||
distributedAttackThreshold int
|
||||
// persistentAttackThreshold triggers persistent attack alerts
|
||||
persistentAttackThreshold int
|
||||
// eventsMutex protects concurrent access to events
|
||||
eventsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// SecurityEventHandler defines the interface for processing security events.
|
||||
// Implementations can log events, send alerts, update external systems,
|
||||
// or trigger automated response actions.
|
||||
type SecurityEventHandler interface {
|
||||
// HandleSecurityEvent processes a security event
|
||||
HandleSecurityEvent(event SecurityEvent)
|
||||
}
|
||||
|
||||
// SecurityMonitorConfig contains configuration parameters for the security monitor.
|
||||
// Controls thresholds, time windows, and behavior for security monitoring.
|
||||
type SecurityMonitorConfig struct {
|
||||
// MaxFailuresPerIP sets the failure threshold before blocking
|
||||
MaxFailuresPerIP int `json:"max_failures_per_ip"`
|
||||
// FailureWindowMinutes defines the time window for counting failures
|
||||
FailureWindowMinutes int `json:"failure_window_minutes"`
|
||||
// BlockDurationMinutes sets how long to block an IP
|
||||
BlockDurationMinutes int `json:"block_duration_minutes"`
|
||||
// RapidFailureThreshold triggers rapid failure detection
|
||||
RapidFailureThreshold int `json:"rapid_failure_threshold"`
|
||||
// CleanupIntervalMinutes sets cleanup frequency for old data
|
||||
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
|
||||
RetentionHours int `json:"retention_hours"`
|
||||
EnablePatternDetection bool `json:"enable_pattern_detection"`
|
||||
EnableDetailedLogging bool `json:"enable_detailed_logging"`
|
||||
LogSuspiciousOnly bool `json:"log_suspicious_only"`
|
||||
}
|
||||
|
||||
// DefaultSecurityMonitorConfig returns a default configuration
|
||||
func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
|
||||
return SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 10,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 60,
|
||||
EnablePatternDetection: true,
|
||||
RapidFailureThreshold: 5,
|
||||
EnableDetailedLogging: true,
|
||||
LogSuspiciousOnly: false,
|
||||
CleanupIntervalMinutes: 30,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSecurityMonitor creates a new security monitor instance
|
||||
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
|
||||
sm := &SecurityMonitor{
|
||||
ipFailures: make(map[string]*IPFailureTracker),
|
||||
eventHandlers: make([]SecurityEventHandler, 0),
|
||||
config: config,
|
||||
logger: logger,
|
||||
patternDetector: NewSuspiciousPatternDetector(),
|
||||
}
|
||||
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// NewSuspiciousPatternDetector creates a new pattern detector
|
||||
func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
|
||||
return &SuspiciousPatternDetector{
|
||||
shortWindow: 1 * time.Minute,
|
||||
mediumWindow: 5 * time.Minute,
|
||||
longWindow: 15 * time.Minute,
|
||||
rapidFailureThreshold: 5,
|
||||
distributedAttackThreshold: 20,
|
||||
persistentAttackThreshold: 50,
|
||||
recentEvents: make([]SecurityEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSecurityEvent is a generic method to record any type of security event
|
||||
func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
eventType SecurityEventType,
|
||||
clientIP, userAgent, requestPath string,
|
||||
message string,
|
||||
details map[string]interface{},
|
||||
trackIPFailure bool) {
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: string(eventType),
|
||||
Severity: eventType.DefaultSeverity(),
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
if trackIPFailure {
|
||||
sm.recordIPFailure(clientIP, eventType.IPFailureType())
|
||||
}
|
||||
|
||||
sm.processSecurityEvent(event)
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["reason"] = reason
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
AuthFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Authentication failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordTokenValidationFailure records a token validation failure
|
||||
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
|
||||
details := map[string]interface{}{
|
||||
"reason": reason,
|
||||
}
|
||||
if tokenPrefix != "" {
|
||||
details["token_prefix"] = tokenPrefix
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
TokenValidFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Token validation failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordRateLimitHit records when rate limiting is triggered
|
||||
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
|
||||
details := map[string]interface{}{
|
||||
"limit_type": "token_verification",
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
RateLimitHit,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
"Rate limit exceeded",
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["activity_type"] = activityType
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
SuspiciousActivity,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// recordIPFailure tracks failures for a specific IP address
|
||||
func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
tracker = &IPFailureTracker{
|
||||
FailureTypes: make(map[string]int64),
|
||||
FirstFailure: time.Now(),
|
||||
}
|
||||
sm.ipFailures[clientIP] = tracker
|
||||
}
|
||||
|
||||
tracker.mutex.Lock()
|
||||
defer tracker.mutex.Unlock()
|
||||
|
||||
tracker.FailureCount++
|
||||
tracker.LastFailure = time.Now()
|
||||
tracker.FailureTypes[failureType]++
|
||||
|
||||
windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute)
|
||||
if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) {
|
||||
if !tracker.IsBlocked {
|
||||
tracker.IsBlocked = true
|
||||
tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute)
|
||||
|
||||
sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes)
|
||||
|
||||
blockEvent := SecurityEvent{
|
||||
Type: "ip_blocked",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
|
||||
Details: map[string]interface{}{
|
||||
"failure_count": tracker.FailureCount,
|
||||
"failure_types": tracker.FailureTypes,
|
||||
"blocked_until": tracker.BlockedUntil,
|
||||
},
|
||||
}
|
||||
sm.processSecurityEvent(blockEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsIPBlocked checks if an IP address is currently blocked
|
||||
func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool {
|
||||
sm.ipMutex.RLock()
|
||||
defer sm.ipMutex.RUnlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
defer tracker.mutex.RUnlock()
|
||||
|
||||
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
|
||||
return true
|
||||
}
|
||||
|
||||
if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) {
|
||||
tracker.IsBlocked = false
|
||||
sm.logger.Infof("IP %s automatically unblocked", clientIP)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// processSecurityEvent processes a security event through all handlers and pattern detection
|
||||
func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
|
||||
if sm.config.EnablePatternDetection {
|
||||
sm.patternDetector.AddEvent(event)
|
||||
|
||||
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
|
||||
if len(patterns) == 1 {
|
||||
sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0])
|
||||
} else {
|
||||
sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns)
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
patternEvent := SecurityEvent{
|
||||
Type: "suspicious_pattern",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
|
||||
Details: map[string]interface{}{
|
||||
"pattern_type": pattern,
|
||||
"trigger_event": event,
|
||||
},
|
||||
}
|
||||
sm.handleSecurityEvent(patternEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sm.handleSecurityEvent(event)
|
||||
}
|
||||
|
||||
// handleSecurityEvent sends the event to all registered handlers
|
||||
func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) {
|
||||
if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") {
|
||||
sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)",
|
||||
event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath)
|
||||
}
|
||||
|
||||
for _, handler := range sm.eventHandlers {
|
||||
go handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// AddEventHandler adds a security event handler
|
||||
func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
|
||||
sm.eventHandlers = append(sm.eventHandlers, handler)
|
||||
}
|
||||
|
||||
// This is kept for API compatibility but doesn't collect actual metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"tracked_ips": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// AddEvent adds an event to the pattern detector
|
||||
func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) {
|
||||
spd.eventsMutex.Lock()
|
||||
defer spd.eventsMutex.Unlock()
|
||||
|
||||
spd.recentEvents = append(spd.recentEvents, event)
|
||||
|
||||
cutoff := time.Now().Add(-spd.longWindow)
|
||||
var filteredEvents []SecurityEvent
|
||||
for _, e := range spd.recentEvents {
|
||||
if e.Timestamp.After(cutoff) {
|
||||
filteredEvents = append(filteredEvents, e)
|
||||
}
|
||||
}
|
||||
spd.recentEvents = filteredEvents
|
||||
}
|
||||
|
||||
// DetectSuspiciousPatterns analyzes recent events for suspicious patterns
|
||||
func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
|
||||
spd.eventsMutex.RLock()
|
||||
defer spd.eventsMutex.RUnlock()
|
||||
|
||||
var patterns []string
|
||||
now := time.Now()
|
||||
|
||||
ipCounts := make(map[string]int)
|
||||
shortWindowStart := now.Add(-spd.shortWindow)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(shortWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
ipCounts[event.ClientIP]++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, count := range ipCounts {
|
||||
if count >= spd.rapidFailureThreshold {
|
||||
patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip))
|
||||
}
|
||||
}
|
||||
|
||||
mediumWindowStart := now.Add(-spd.mediumWindow)
|
||||
uniqueFailingIPs := make(map[string]bool)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(mediumWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
uniqueFailingIPs[event.ClientIP] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(uniqueFailingIPs) >= spd.distributedAttackThreshold {
|
||||
patterns = append(patterns, "distributed_attack_pattern")
|
||||
}
|
||||
|
||||
longWindowStart := now.Add(-spd.longWindow)
|
||||
persistentFailures := 0
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(longWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
persistentFailures++
|
||||
}
|
||||
}
|
||||
|
||||
if persistentFailures >= spd.persistentAttackThreshold {
|
||||
patterns = append(patterns, "persistent_attack_pattern")
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
// startCleanupRoutine starts the background cleanup routine
|
||||
func (sm *SecurityMonitor) startCleanupRoutine() {
|
||||
sm.cleanupTask = NewBackgroundTask(
|
||||
"security-monitor-cleanup",
|
||||
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
|
||||
sm.cleanup,
|
||||
sm.logger)
|
||||
sm.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// StopCleanupRoutine stops the background cleanup routine
|
||||
func (sm *SecurityMonitor) StopCleanupRoutine() {
|
||||
if sm.cleanupTask != nil {
|
||||
sm.cleanupTask.Stop()
|
||||
sm.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old tracking data
|
||||
func (sm *SecurityMonitor) cleanup() {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour)
|
||||
|
||||
for ip, tracker := range sm.ipFailures {
|
||||
tracker.mutex.RLock()
|
||||
shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
if shouldRemove {
|
||||
delete(sm.ipFailures, ip)
|
||||
}
|
||||
}
|
||||
|
||||
sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures))
|
||||
}
|
||||
|
||||
// ExtractClientIP extracts the client IP from the request, considering proxy headers
|
||||
func ExtractClientIP(r *http.Request) string {
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if net.ParseIP(xri) != nil {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
ip := strings.TrimSpace(ips[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// LoggingSecurityEventHandler logs security events to the standard logger
|
||||
type LoggingSecurityEventHandler struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewLoggingSecurityEventHandler creates a new logging event handler
|
||||
func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler {
|
||||
return &LoggingSecurityEventHandler{logger: logger}
|
||||
}
|
||||
|
||||
// HandleSecurityEvent implements SecurityEventHandler
|
||||
func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
switch event.Severity {
|
||||
case "high":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "medium":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "low":
|
||||
h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
default:
|
||||
h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
}
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSecurityMonitor(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
config.MaxFailuresPerIP = 3
|
||||
config.BlockDurationMinutes = 1 // 1 minute for testing
|
||||
config.CleanupIntervalMinutes = 1
|
||||
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
defer func() {
|
||||
// Allow cleanup goroutine to finish
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}()
|
||||
|
||||
t.Run("Record authentication failure", func(t *testing.T) {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "invalid credentials", nil)
|
||||
|
||||
// Should not be blocked after first failure
|
||||
if monitor.IsIPBlocked("192.168.1.1") {
|
||||
t.Error("IP should not be blocked after first failure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IP blocked after max failures", func(t *testing.T) {
|
||||
// Record multiple failures
|
||||
for i := 0; i < config.MaxFailuresPerIP; i++ {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.2", "test-agent", "/login", "invalid credentials", nil)
|
||||
}
|
||||
|
||||
// Should be blocked now
|
||||
if !monitor.IsIPBlocked("192.168.1.2") {
|
||||
t.Error("IP should be blocked after max failures")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token validation failure", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
|
||||
})
|
||||
|
||||
t.Run("Rate limit hit", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
|
||||
})
|
||||
|
||||
t.Run("Suspicious activity", func(t *testing.T) {
|
||||
details := map[string]interface{}{"pattern": "unusual"}
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
detector := NewSuspiciousPatternDetector()
|
||||
|
||||
t.Run("Add events and detect patterns", func(t *testing.T) {
|
||||
// Add multiple events from same IP
|
||||
for i := 0; i < 10; i++ {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.100",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
detector.AddEvent(event)
|
||||
}
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := false
|
||||
for _, p := range patterns {
|
||||
if p == "rapid_failures_from_ip_192.168.1.100" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to detect rapid failure pattern")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Detect distributed attack pattern", func(t *testing.T) {
|
||||
// Add failures from many different IPs
|
||||
for i := 0; i < 25; i++ {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1." + strconv.Itoa(100+i),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
detector.AddEvent(event)
|
||||
}
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := false
|
||||
for _, p := range patterns {
|
||||
if p == "distributed_attack_pattern" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to detect distributed attack pattern")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
headers map[string]string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "Direct connection",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 10.0.0.1"},
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP header",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{"X-Real-IP": "203.0.113.2"},
|
||||
expectedIP: "203.0.113.2",
|
||||
},
|
||||
{
|
||||
name: "Multiple headers - X-Real-IP takes precedence",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "203.0.113.1",
|
||||
"X-Real-IP": "203.0.113.2",
|
||||
},
|
||||
expectedIP: "203.0.113.2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
ip := ExtractClientIP(req)
|
||||
if ip != tt.expectedIP {
|
||||
t.Errorf("Expected IP %s, got %s", tt.expectedIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityEventHandlers(t *testing.T) {
|
||||
t.Run("Logging security event handler", func(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
handler := NewLoggingSecurityEventHandler(logger)
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.1",
|
||||
Timestamp: time.Now(),
|
||||
Message: "Test failure",
|
||||
Severity: "medium",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
handler.HandleSecurityEvent(event)
|
||||
})
|
||||
|
||||
// Metrics security event handler test removed as part of metrics cleanup
|
||||
}
|
||||
|
||||
func TestSecurityMonitorEventHandlers(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Add event handler with proper synchronization
|
||||
handlerCalled := make(chan bool, 1)
|
||||
handler := &testSecurityEventHandler{
|
||||
callback: func(event SecurityEvent) {
|
||||
select {
|
||||
case handlerCalled <- true:
|
||||
default:
|
||||
// Channel already has a value, don't block
|
||||
}
|
||||
},
|
||||
}
|
||||
monitor.AddEventHandler(handler)
|
||||
|
||||
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "test failure", nil)
|
||||
|
||||
// Wait for event handler to be called with timeout
|
||||
select {
|
||||
case <-handlerCalled:
|
||||
// Success - handler was called
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Expected event handler to be called within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper for security event handler
|
||||
type testSecurityEventHandler struct {
|
||||
callback func(SecurityEvent)
|
||||
}
|
||||
|
||||
func (h *testSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
h.callback(event)
|
||||
}
|
||||
|
||||
func TestDefaultSecurityMonitorConfig(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
|
||||
if config.MaxFailuresPerIP <= 0 {
|
||||
t.Error("Expected positive MaxFailuresPerIP")
|
||||
}
|
||||
if config.BlockDurationMinutes <= 0 {
|
||||
t.Error("Expected positive BlockDurationMinutes")
|
||||
}
|
||||
if config.CleanupIntervalMinutes <= 0 {
|
||||
t.Error("Expected positive CleanupIntervalMinutes")
|
||||
}
|
||||
if config.FailureWindowMinutes <= 0 {
|
||||
t.Error("Expected positive FailureWindowMinutes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityMonitorCleanup(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
config.CleanupIntervalMinutes = 1
|
||||
config.BlockDurationMinutes = 1
|
||||
config.RetentionHours = 1
|
||||
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Block an IP
|
||||
for i := 0; i < config.MaxFailuresPerIP; i++ {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.99", "test-agent", "/login", "test", nil)
|
||||
}
|
||||
|
||||
// Verify it's blocked
|
||||
if !monitor.IsIPBlocked("192.168.1.99") {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
|
||||
// Wait a bit and check if it gets unblocked automatically
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// The IP should still be blocked since we haven't waited long enough
|
||||
if !monitor.IsIPBlocked("192.168.1.99") {
|
||||
t.Error("IP should still be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityEventTypes(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Test different event types - just verify they don't panic
|
||||
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
|
||||
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
|
||||
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
|
||||
|
||||
details := map[string]interface{}{"pattern": "test"}
|
||||
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
|
||||
|
||||
// Just verify GetSecurityMetrics doesn't panic
|
||||
_ = monitor.GetSecurityMetrics()
|
||||
}
|
||||
+55
-1
@@ -4,7 +4,9 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
@@ -31,6 +33,45 @@ func constantTimeStringCompare(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
|
||||
// deriveCookieKeys derives an independent 64-byte HMAC authentication key and a
|
||||
// 32-byte AES-256 encryption key from the operator-provided session encryption
|
||||
// key using HKDF-SHA256 (RFC 5869).
|
||||
//
|
||||
// gorilla/securecookie only ENCRYPTS the cookie payload when a block
|
||||
// (encryption) key is supplied; constructing the store with a single key leaves
|
||||
// sessions signed-but-plaintext, so the OIDC access/refresh/ID tokens stored in
|
||||
// the cookie are recoverable by anyone who can read the raw cookie bytes. Two
|
||||
// independent keys are derived here so the cookie is both encrypted and
|
||||
// authenticated. HKDF is implemented with stdlib hmac+sha256 so it runs under
|
||||
// Traefik's yaegi interpreter, which may not export crypto/hkdf.
|
||||
func deriveCookieKeys(secret string) (authKey, encKey []byte) {
|
||||
okm := hkdfSHA256([]byte(secret), nil, []byte("traefikoidc session cookie keys v1"), 96)
|
||||
return okm[:64], okm[64:96]
|
||||
}
|
||||
|
||||
// hkdfSHA256 performs HKDF-Extract followed by HKDF-Expand (RFC 5869) using
|
||||
// HMAC-SHA256 and returns length bytes of output keying material.
|
||||
func hkdfSHA256(ikm, salt, info []byte, length int) []byte {
|
||||
if len(salt) == 0 {
|
||||
salt = make([]byte, sha256.Size)
|
||||
}
|
||||
// Extract: PRK = HMAC-SHA256(salt, IKM)
|
||||
ext := hmac.New(sha256.New, salt)
|
||||
ext.Write(ikm)
|
||||
prk := ext.Sum(nil)
|
||||
// Expand: T(i) = HMAC-SHA256(PRK, T(i-1) | info | i)
|
||||
var out, t []byte
|
||||
for i := byte(1); len(out) < length; i++ {
|
||||
exp := hmac.New(sha256.New, prk)
|
||||
exp.Write(t)
|
||||
exp.Write(info)
|
||||
exp.Write([]byte{i})
|
||||
t = exp.Sum(nil)
|
||||
out = append(out, t...)
|
||||
}
|
||||
return out[:length]
|
||||
}
|
||||
|
||||
// min returns the minimum of two integers.
|
||||
// This is a utility function used throughout the session management code.
|
||||
// Parameters:
|
||||
@@ -423,8 +464,13 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Derive independent authentication + encryption keys so the session cookie
|
||||
// is AES-256 encrypted and HMAC authenticated, not merely signed. See
|
||||
// deriveCookieKeys: a single key would leave the stored tokens in plaintext.
|
||||
authKey, encKey := deriveCookieKeys(encryptionKey)
|
||||
|
||||
sm := &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
store: sessions.NewCookieStore(authKey, encKey),
|
||||
forceHTTPS: forceHTTPS,
|
||||
cookieDomain: cookieDomain,
|
||||
cookiePrefix: cookiePrefix,
|
||||
@@ -435,6 +481,14 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Bind the cookie codec's timestamp validity (and the cookie Max-Age) to the
|
||||
// configured session lifetime instead of gorilla's 30-day default, so a
|
||||
// stolen cookie is not cryptographically valid for up to 30 days regardless
|
||||
// of the (possibly much shorter) configured sessionMaxAge (rank 13).
|
||||
if cs, ok := sm.store.(*sessions.CookieStore); ok {
|
||||
cs.MaxAge(int(sessionMaxAge.Seconds()))
|
||||
}
|
||||
|
||||
// Initialize global memory monitoring (singleton)
|
||||
sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
|
||||
|
||||
|
||||
+27
-1
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -762,7 +763,32 @@ func validateTemplateSecure(templateStr string) error {
|
||||
// Returns true if the URL is valid and secure (HTTPS), false otherwise.
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
if err != nil || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
if u.Scheme == "https" {
|
||||
return true
|
||||
}
|
||||
// Permit plaintext HTTP only for loopback hosts (local development,
|
||||
// in-cluster sidecar providers, tests). Loopback traffic never leaves the
|
||||
// host, so it is not exposed to network MITM; remote providers must use
|
||||
// HTTPS. Mirrors the RFC 8252 loopback allowance.
|
||||
if u.Scheme == "http" && isLoopbackHost(u.Hostname()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLoopbackHost reports whether host is "localhost" or a loopback IP literal
|
||||
// (127.0.0.0/8 or ::1).
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
|
||||
|
||||
+32
-2
@@ -106,8 +106,9 @@ func (rm *ResourceManager) GetCache(key string) interface{} {
|
||||
case "jwk-cache":
|
||||
cache = cacheManager.GetSharedJWKCache()
|
||||
default:
|
||||
// Generic cache implementation
|
||||
cache = NewGenericCache(1*time.Hour, rm.logger)
|
||||
// Generic cache implementation; bind cleanup goroutine to the manager's
|
||||
// shutdown channel so it exits when the ResourceManager shuts down.
|
||||
cache = newGenericCacheWithOwner(1*time.Hour, rm.logger, rm.shutdownChan)
|
||||
}
|
||||
|
||||
rm.caches[key] = cache
|
||||
@@ -263,6 +264,19 @@ func (rm *ResourceManager) cleanupInstance(instanceID string) {
|
||||
// This is a hook for future instance-specific cleanup needs
|
||||
}
|
||||
|
||||
// liveInstanceCount tracks the number of fully-constructed TraefikOidc plugin
|
||||
// instances alive in this process. Process-global singleton tasks (such as the
|
||||
// shared token-cleanup) must only be stopped when the LAST instance shuts down,
|
||||
// otherwise one instance's teardown would disable them for all survivors.
|
||||
var liveInstanceCount int32
|
||||
|
||||
// registerLiveInstance records a newly constructed plugin instance.
|
||||
func registerLiveInstance() { atomic.AddInt32(&liveInstanceCount, 1) }
|
||||
|
||||
// unregisterLiveInstance records a plugin instance shutting down and returns the
|
||||
// number of instances still alive afterwards.
|
||||
func unregisterLiveInstance() int32 { return atomic.AddInt32(&liveInstanceCount, -1) }
|
||||
|
||||
// Shutdown gracefully shuts down all managed resources
|
||||
func (rm *ResourceManager) Shutdown(ctx context.Context) error {
|
||||
var err error
|
||||
@@ -502,6 +516,9 @@ func (p *GoroutinePool) Shutdown(ctx context.Context) error {
|
||||
// GenericCache provides a simple cache implementation for testing
|
||||
type GenericCache struct {
|
||||
data map[string]interface{}
|
||||
// ownerStopChan, when non-nil, signals the cleanup goroutine to exit when
|
||||
// the owning ResourceManager shuts down, so the goroutine cannot outlive it.
|
||||
ownerStopChan <-chan struct{}
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
ttl time.Duration
|
||||
@@ -510,11 +527,19 @@ type GenericCache struct {
|
||||
|
||||
// NewGenericCache creates a new generic cache
|
||||
func NewGenericCache(ttl time.Duration, logger *Logger) *GenericCache {
|
||||
return newGenericCacheWithOwner(ttl, logger, nil)
|
||||
}
|
||||
|
||||
// newGenericCacheWithOwner creates a generic cache whose cleanup goroutine also
|
||||
// exits when ownerStopChan is closed (typically the ResourceManager shutdown
|
||||
// channel), guaranteeing the goroutine is stoppable on shutdown.
|
||||
func newGenericCacheWithOwner(ttl time.Duration, logger *Logger, ownerStopChan <-chan struct{}) *GenericCache {
|
||||
cache := &GenericCache{
|
||||
data: make(map[string]interface{}),
|
||||
ttl: ttl,
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
ownerStopChan: ownerStopChan,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
@@ -570,6 +595,11 @@ func (gc *GenericCache) cleanupRoutine() {
|
||||
gc.mu.Unlock()
|
||||
case <-gc.stopChan:
|
||||
return
|
||||
case <-gc.ownerStopChan:
|
||||
// Owning ResourceManager is shutting down; exit so the goroutine
|
||||
// does not outlive its owner. A nil channel blocks forever, so this
|
||||
// case is inert when no owner is set.
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -299,6 +299,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
plugin, err := NewWithContext(ctx, config, nil, "test")
|
||||
@@ -350,9 +353,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
configs := []Config{
|
||||
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
|
||||
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
|
||||
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
|
||||
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
|
||||
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
|
||||
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
|
||||
}
|
||||
|
||||
var plugins []*TraefikOidc
|
||||
@@ -435,6 +438,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ProviderURL: mockServers[i].URL,
|
||||
ClientID: fmt.Sprintf("client%d", i),
|
||||
ClientSecret: fmt.Sprintf("secret%d", i),
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
plugin, err := NewWithContext(ctx, config, nil, fmt.Sprintf("test-%d", i))
|
||||
@@ -598,6 +604,9 @@ func TestBackwardCompatibility(t *testing.T) {
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
handler, err := New(context.Background(), nil, config, "test")
|
||||
@@ -620,6 +629,9 @@ func TestBackwardCompatibility(t *testing.T) {
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
handler, _ := New(context.Background(), nil, config, "test")
|
||||
|
||||
+23
-8
@@ -21,7 +21,10 @@ type IntrospectionResponse struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Aud string `json:"aud,omitempty"`
|
||||
// Aud holds the introspection audience. Per RFC 7662 it may be a single
|
||||
// string or an array of strings, so it is decoded as interface{} and
|
||||
// matched with verifyAudience (which handles both shapes).
|
||||
Aud interface{} `json:"aud,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
Jti string `json:"jti,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
@@ -120,7 +123,7 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
|
||||
|
||||
// Parse response per RFC 7662 Section 2.2
|
||||
var introspectionResp IntrospectionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&introspectionResp); err != nil {
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&introspectionResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode introspection response: %w", err)
|
||||
}
|
||||
|
||||
@@ -128,6 +131,12 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
|
||||
if t.introspectionCache != nil {
|
||||
// Cache for a short duration or until token expiry (whichever is shorter)
|
||||
cacheDuration := 5 * time.Minute
|
||||
// When introspection is REQUIRED, operators expect near-real-time
|
||||
// revocation; cap the positive-result cache so a token revoked at the
|
||||
// provider cannot keep passing for the full 5 minutes (rank 8).
|
||||
if t.requireTokenIntrospection && cacheDuration > 30*time.Second {
|
||||
cacheDuration = 30 * time.Second
|
||||
}
|
||||
if introspectionResp.Exp > 0 {
|
||||
expTime := time.Unix(introspectionResp.Exp, 0)
|
||||
untilExp := time.Until(expTime)
|
||||
@@ -197,12 +206,18 @@ func (t *TraefikOidc) validateOpaqueToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate audience if configured
|
||||
// Note: For opaque tokens, audience validation via introspection may be limited
|
||||
// depending on what the introspection endpoint returns
|
||||
if t.audience != "" && t.audience != t.clientID && resp.Aud != "" {
|
||||
if resp.Aud != t.audience {
|
||||
return fmt.Errorf("invalid audience: expected %s, got %s", t.audience, resp.Aud)
|
||||
// Validate audience if configured. When a distinct API audience is
|
||||
// configured (audience != clientID), the introspection response MUST carry
|
||||
// a matching audience. Fail closed on a missing or mismatched aud: a token
|
||||
// whose audience cannot be confirmed must not be accepted, otherwise a
|
||||
// token minted for a different audience would pass. aud may be a single
|
||||
// string or an array of strings (RFC 7662); verifyAudience handles both.
|
||||
if t.audience != "" && t.audience != t.clientID {
|
||||
if resp.Aud == nil {
|
||||
return fmt.Errorf("invalid audience: expected %s, introspection response has no audience", t.audience)
|
||||
}
|
||||
if err := verifyAudience(resp.Aud, t.audience); err != nil {
|
||||
return fmt.Errorf("invalid audience: expected %s: %w", t.audience, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+9
-6
@@ -5,6 +5,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -212,11 +214,13 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Complex token type detection with multiple provider-specific checks
|
||||
func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool {
|
||||
// Use first 32 chars of token as cache key (sufficient for uniqueness)
|
||||
cacheKey := token
|
||||
if len(token) > 32 {
|
||||
cacheKey = token[:32]
|
||||
}
|
||||
// Key on a hash of the FULL token. The first 32 characters of a JWT are
|
||||
// only the base64url-encoded header, which is identical for every token
|
||||
// sharing the same alg+kid, so distinct tokens (e.g. an ID token and an
|
||||
// access token from the same issuer) would otherwise collide on the cache
|
||||
// key and be mis-classified.
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
cacheKey := hex.EncodeToString(sum[:])
|
||||
|
||||
// Check cache first
|
||||
if t.tokenTypeCache != nil {
|
||||
@@ -858,7 +862,6 @@ func (t *TraefikOidc) isAzureProvider() bool {
|
||||
strings.Contains(issuerURL, "login.windows.net")
|
||||
}
|
||||
|
||||
|
||||
// startTokenCleanup starts background cleanup goroutines for cache maintenance.
|
||||
// It runs periodic cleanup of token cache, JWK cache, and session chunks.
|
||||
// Includes panic recovery to ensure stability.
|
||||
|
||||
+5
-535
@@ -3,7 +3,9 @@ package traefikoidc
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -885,10 +887,9 @@ func TestDetectTokenTypeCaching(t *testing.T) {
|
||||
},
|
||||
}
|
||||
token := "test-token-for-caching-with-enough-characters-for-key"
|
||||
cacheKey := token
|
||||
if len(token) > 32 {
|
||||
cacheKey = token[:32]
|
||||
}
|
||||
// The cache key is a SHA-256 hash of the full token (collision-resistant).
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
cacheKey := hex.EncodeToString(sum[:])
|
||||
|
||||
result := tr.detectTokenType(jwt, token)
|
||||
if !result {
|
||||
@@ -911,521 +912,6 @@ func TestDetectTokenTypeCaching(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TOKEN VALIDATOR TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestNewTokenValidator(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected non-nil token validator")
|
||||
}
|
||||
|
||||
if validator.logger == nil {
|
||||
t.Error("Expected logger to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTokenValidatorWithLogger(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
validator := NewTokenValidator(logger)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected non-nil token validator")
|
||||
}
|
||||
|
||||
if validator.logger != logger {
|
||||
t.Error("Expected provided logger to be used")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenEmpty(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
result := validator.ValidateToken("", false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for empty token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for empty token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "empty") {
|
||||
t.Errorf("Expected 'empty' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenRequireJWT(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
result := validator.ValidateToken("opaque_token_value_here", true)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for opaque token when JWT required")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error when JWT required but opaque token provided")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTValidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("Expected valid result, got error: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.TokenType != "JWT" {
|
||||
t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType)
|
||||
}
|
||||
|
||||
if result.Claims == nil {
|
||||
t.Error("Expected claims to be parsed")
|
||||
}
|
||||
|
||||
if result.Expiry == nil {
|
||||
t.Error("Expected expiry to be extracted")
|
||||
}
|
||||
|
||||
if result.IssuedAt == nil {
|
||||
t.Error("Expected issued at to be extracted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTExpiredToken(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for expired token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for expired token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "expired") {
|
||||
t.Errorf("Expected 'expired' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTFutureIssuedAt(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(10 * time.Minute).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for future iat")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for future iat")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "future") {
|
||||
t.Errorf("Expected 'future' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTNotBeforeClaim(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"nbf": time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for nbf in future")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for nbf in future")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "not yet valid") {
|
||||
t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTInvalidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"single part", "eyJhbGciOiJIUzI1NiJ9"},
|
||||
{"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"},
|
||||
{"four parts", "part1.part2.part3.part4"},
|
||||
{"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateToken(tt.token, true)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for malformed JWT")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for malformed JWT")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenValid(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "sk_live_abcdef123456GHIJKL789"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("Expected valid result, got error: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.TokenType != "Opaque" {
|
||||
t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenTooShort(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "short"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for short token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for short token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "too short") {
|
||||
t.Errorf("Expected 'too short' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenWithSpaces(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "this token has spaces in it"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for token with spaces")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for token with spaces")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "spaces") {
|
||||
t.Errorf("Expected 'spaces' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenControlCharacters(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "token_with\x00control_char"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for token with control characters")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for token with control characters")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "control character") {
|
||||
t.Errorf("Expected 'control character' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "aaaaaabbbbbbccccccdddd"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for low entropy token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for low entropy token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "entropy") {
|
||||
t.Errorf("Expected 'entropy' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidBase64URL(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true},
|
||||
{"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true},
|
||||
{"valid numbers", "0123456789", true},
|
||||
{"valid dash", "abc-def", true},
|
||||
{"valid underscore", "abc_def", true},
|
||||
{"valid equals", "abc=", true},
|
||||
{"invalid at sign", "abc@def", false},
|
||||
{"invalid space", "abc def", false},
|
||||
{"invalid plus", "abc+def", false},
|
||||
{"invalid slash", "abc/def", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.isValidBase64URL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTime(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
claim interface{}
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{name: "float64", claim: float64(1609459200), expected: true},
|
||||
{name: "int64", claim: int64(1609459200), expected: true},
|
||||
{name: "int", claim: int(1609459200), expected: true},
|
||||
{name: "string", claim: "not a timestamp", expected: false},
|
||||
{name: "nil", claim: nil, expected: false},
|
||||
{name: "map", claim: map[string]interface{}{}, expected: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.extractTime(tt.claim)
|
||||
|
||||
if tt.expected && result == nil {
|
||||
t.Error("Expected non-nil time")
|
||||
}
|
||||
|
||||
if !tt.expected && result != nil {
|
||||
t.Error("Expected nil time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenSize(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
maxSize int
|
||||
expectError bool
|
||||
}{
|
||||
{"within limit", "short_token", 20, false},
|
||||
{"at limit", "exactly_twenty_c", 16, false},
|
||||
{"exceeds limit", "this_token_is_too_long", 10, true},
|
||||
{"empty token", "", 10, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateTokenSize(tt.token, tt.maxSize)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error for oversized token")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "exceeds") {
|
||||
t.Errorf("Expected 'exceeds' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(1609459200),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
extracted, err := validator.ExtractClaims(token)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extracted == nil {
|
||||
t.Fatal("Expected non-nil claims")
|
||||
}
|
||||
|
||||
if extracted["sub"] != "user123" {
|
||||
t.Errorf("Expected sub 'user123', got %v", extracted["sub"])
|
||||
}
|
||||
|
||||
if extracted["email"] != "user@example.com" {
|
||||
t.Errorf("Expected email 'user@example.com', got %v", extracted["email"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsInvalidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"single part", "onlyonepart"},
|
||||
{"two parts", "two.parts"},
|
||||
{"four parts", "one.two.three.four"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := validator.ExtractClaims(tt.token)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid format")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid JWT format") {
|
||||
t.Errorf("Expected 'invalid JWT format' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensEqual(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "secret_token_12345"
|
||||
token2 := "secret_token_12345"
|
||||
|
||||
if !validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensDifferent(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "secret_token_12345"
|
||||
token2 := "secret_token_54321"
|
||||
|
||||
if validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be different")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensDifferentLength(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "short"
|
||||
token2 := "much_longer_token"
|
||||
|
||||
if validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be different (different lengths)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensEmpty(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := ""
|
||||
token2 := ""
|
||||
|
||||
if !validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected empty tokens to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenMaliciousPayloads(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"sql injection attempt", "'; DROP TABLE users; --"},
|
||||
{"xss attempt", "<script>alert('xss')</script>"},
|
||||
{"path traversal", "../../../etc/passwd"},
|
||||
{"null bytes", "token\x00with\x00nulls"},
|
||||
{"unicode exploit", "token\u0000\u0001\u0002"},
|
||||
{"extremely long", strings.Repeat("a", 100000)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateToken(tt.token, false)
|
||||
|
||||
if result.Valid {
|
||||
if result.Claims != nil {
|
||||
t.Logf("Token considered valid: %s", tt.name)
|
||||
}
|
||||
} else {
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for malicious payload")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONSOLIDATED TOKEN TESTS
|
||||
// =============================================================================
|
||||
@@ -2098,19 +1584,3 @@ func createTokenOfSize(baseToken string, targetSize int) string {
|
||||
|
||||
return baseToken
|
||||
}
|
||||
|
||||
func createTestJWTSimple(claims map[string]interface{}) string {
|
||||
header := map[string]interface{}{
|
||||
"alg": "HS256",
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature"))
|
||||
|
||||
return headerB64 + "." + claimsB64 + "." + signature
|
||||
}
|
||||
|
||||
+15
-2
@@ -149,7 +149,15 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo
|
||||
if rs.idToken != "" {
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
return true, false, false
|
||||
// No ID token to corroborate an access token we cannot verify
|
||||
// (Azure nonce-bearing Graph access tokens carry a proprietary,
|
||||
// client-unverifiable signature). Do NOT authenticate on an
|
||||
// unverified token: refresh if a refresh token is available,
|
||||
// otherwise force re-authentication.
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +166,12 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return true, false, false
|
||||
// Opaque access token, no ID token to corroborate it, and
|
||||
// introspection was unavailable/disabled/errored (e.g.
|
||||
// circuit-breaker open). There is nothing left to verify the token
|
||||
// against, so fail closed and force re-authentication rather than
|
||||
// trusting an unverified opaque token.
|
||||
return false, false, true
|
||||
}
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/pool"
|
||||
)
|
||||
|
||||
// TokenValidator provides unified token validation functionality
|
||||
type TokenValidator struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewTokenValidator creates a new token validator
|
||||
func NewTokenValidator(logger *Logger) *TokenValidator {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
return &TokenValidator{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TokenValidationResult contains the result of token validation
|
||||
type TokenValidationResult struct {
|
||||
Error error
|
||||
Claims map[string]interface{}
|
||||
Expiry *time.Time
|
||||
IssuedAt *time.Time
|
||||
TokenType string
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// ValidateToken performs comprehensive token validation
|
||||
func (v *TokenValidator) ValidateToken(token string, requireJWT bool) TokenValidationResult {
|
||||
result := TokenValidationResult{}
|
||||
|
||||
// Basic validation
|
||||
if token == "" {
|
||||
result.Error = fmt.Errorf("token is empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check if it's a JWT or opaque token
|
||||
dotCount := strings.Count(token, ".")
|
||||
isJWT := dotCount == 2
|
||||
|
||||
if requireJWT && !isJWT {
|
||||
result.Error = fmt.Errorf("token is not a valid JWT (found %d dots, expected 2)", dotCount)
|
||||
return result
|
||||
}
|
||||
|
||||
if isJWT {
|
||||
return v.validateJWT(token)
|
||||
} else {
|
||||
return v.validateOpaqueToken(token)
|
||||
}
|
||||
}
|
||||
|
||||
// validateJWT validates a JWT token
|
||||
func (v *TokenValidator) validateJWT(token string) TokenValidationResult {
|
||||
result := TokenValidationResult{
|
||||
TokenType: "JWT",
|
||||
}
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
result.Error = fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate each part
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
result.Error = fmt.Errorf("JWT part %d is empty", i)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for valid base64url characters
|
||||
if !v.isValidBase64URL(part) {
|
||||
result.Error = fmt.Errorf("JWT part %d contains invalid base64url characters", i)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Decode and parse claims
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
pm := pool.Get()
|
||||
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
|
||||
defer pm.PutJSONDecoder(decoder)
|
||||
if err := decoder.Decode(&claims); err != nil {
|
||||
result.Error = fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
result.Claims = claims
|
||||
|
||||
// Extract standard claims
|
||||
if exp, ok := claims["exp"]; ok {
|
||||
expTime := v.extractTime(exp)
|
||||
if expTime != nil {
|
||||
result.Expiry = expTime
|
||||
// Check if expired
|
||||
if time.Now().After(*expTime) {
|
||||
result.Error = fmt.Errorf("token is expired (expired at %v)", expTime.Format(time.RFC3339))
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if iat, ok := claims["iat"]; ok {
|
||||
iatTime := v.extractTime(iat)
|
||||
if iatTime != nil {
|
||||
result.IssuedAt = iatTime
|
||||
// Check if issued in future
|
||||
if iatTime.After(time.Now().Add(5 * time.Minute)) {
|
||||
result.Error = fmt.Errorf("token issued in future (iat: %v)", iatTime.Format(time.RFC3339))
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check nbf (not before)
|
||||
if nbf, ok := claims["nbf"]; ok {
|
||||
nbfTime := v.extractTime(nbf)
|
||||
if nbfTime != nil && time.Now().Before(*nbfTime) {
|
||||
result.Error = fmt.Errorf("token not yet valid (nbf: %v)", nbfTime.Format(time.RFC3339))
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.Valid = true
|
||||
return result
|
||||
}
|
||||
|
||||
// validateOpaqueToken validates an opaque token
|
||||
func (v *TokenValidator) validateOpaqueToken(token string) TokenValidationResult {
|
||||
result := TokenValidationResult{
|
||||
TokenType: "Opaque",
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(token) < 20 {
|
||||
result.Error = fmt.Errorf("opaque token too short (length: %d)", len(token))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for spaces
|
||||
if strings.Contains(token, " ") {
|
||||
result.Error = fmt.Errorf("opaque token contains spaces")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for control characters
|
||||
for i, char := range token {
|
||||
if char < 32 || char == 127 {
|
||||
result.Error = fmt.Errorf("opaque token contains control character at position %d", i)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check entropy
|
||||
if len(token) >= 20 {
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, char := range token {
|
||||
uniqueChars[char] = true
|
||||
}
|
||||
if len(uniqueChars) < 8 {
|
||||
result.Error = fmt.Errorf("opaque token has insufficient entropy (unique chars: %d)", len(uniqueChars))
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.Valid = true
|
||||
return result
|
||||
}
|
||||
|
||||
// isValidBase64URL checks if a string contains only valid base64url characters
|
||||
func (v *TokenValidator) isValidBase64URL(s string) bool {
|
||||
for _, char := range s {
|
||||
if !((char >= 'A' && char <= 'Z') ||
|
||||
(char >= 'a' && char <= 'z') ||
|
||||
(char >= '0' && char <= '9') ||
|
||||
char == '-' || char == '_' || char == '=') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// extractTime extracts a time.Time from various claim formats
|
||||
func (v *TokenValidator) extractTime(claim interface{}) *time.Time {
|
||||
var timestamp int64
|
||||
|
||||
switch val := claim.(type) {
|
||||
case float64:
|
||||
timestamp = int64(val)
|
||||
case int64:
|
||||
timestamp = val
|
||||
case int:
|
||||
timestamp = int64(val)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
t := time.Unix(timestamp, 0)
|
||||
return &t
|
||||
}
|
||||
|
||||
// ValidateTokenSize checks if token size is within acceptable limits
|
||||
func (v *TokenValidator) ValidateTokenSize(token string, maxSize int) error {
|
||||
if len(token) > maxSize {
|
||||
return fmt.Errorf("token exceeds maximum size (size: %d, max: %d)", len(token), maxSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractClaims extracts claims from a JWT without full validation
|
||||
func (v *TokenValidator) ExtractClaims(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
pm := pool.Get()
|
||||
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
|
||||
defer pm.PutJSONDecoder(decoder)
|
||||
if err := decoder.Decode(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// CompareTokens safely compares two tokens for equality
|
||||
func (v *TokenValidator) CompareTokens(token1, token2 string) bool {
|
||||
if len(token1) != len(token2) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use constant-time comparison to prevent timing attacks
|
||||
var result byte
|
||||
for i := 0; i < len(token1); i++ {
|
||||
result |= token1[i] ^ token2[i]
|
||||
}
|
||||
return result == 0
|
||||
}
|
||||
+21
-10
@@ -957,17 +957,28 @@ func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl tim
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
// Replace an existing entry in place: update the item and move its single
|
||||
// list node to the front. Without this, a repeat populate of the same key
|
||||
// (the per-request Get->backend-hit path) would PushFront a duplicate node
|
||||
// and overwrite c.items[key], orphaning the previous node. Orphans inflate
|
||||
// currentMemory/currentSize and, once eviction deletes the key, leave a
|
||||
// Back() node whose key is absent from c.items — so evictOldest() spins
|
||||
// while holding c.mu.Lock(): the 100%-CPU write-lock convoy seen in pprof.
|
||||
// setLocal dedups the same way; evictOldest also guards any dangling node.
|
||||
if existing, exists := c.items[key]; exists {
|
||||
c.currentMemory -= existing.Size
|
||||
c.lruList.Remove(existing.element)
|
||||
|
||||
// Replace any existing entry in place. Without this, a repeat populate of
|
||||
// the same key (the per-request Get->backend-hit path at line ~359)
|
||||
// PushFronts a second list node and overwrites c.items[key], orphaning the
|
||||
// previous node. Orphans inflate currentMemory/currentSize and — once the
|
||||
// eviction loop deletes the key — leave Back() nodes whose key is absent
|
||||
// from c.items, so evictOldest() no-ops while lruList.Len()>0 stays true:
|
||||
// an infinite loop while holding c.mu.Lock(), i.e. the 100%-CPU holder and
|
||||
// write-lock convoy. setLocal already dedups on this path; this mirrors it.
|
||||
if existing, ok := c.items[key]; ok {
|
||||
c.removeItem(key, existing)
|
||||
existing.Value = value
|
||||
existing.Size = size
|
||||
existing.ExpiresAt = now.Add(ttl)
|
||||
existing.LastAccessed = now
|
||||
existing.AccessCount++
|
||||
|
||||
existing.element = c.lruList.PushFront(key)
|
||||
c.currentMemory += size
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
item := &CacheItem{
|
||||
|
||||
+83
-1
@@ -19,7 +19,7 @@ import (
|
||||
// - true if the URL should be excluded from authentication, false otherwise.
|
||||
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
for excludedURL := range t.excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
if pathExcluded(currentRequest, excludedURL) {
|
||||
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
@@ -27,6 +27,31 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// pathExcluded reports whether requestPath is covered by an excluded prefix at a
|
||||
// natural boundary: an exact match, a sub-path ("/public" → "/public/x"), or a
|
||||
// file extension ("/favicon" → "/favicon.ico"). It deliberately does NOT match
|
||||
// an unrelated sibling such as "/publicsecret", so a configured exclusion can no
|
||||
// longer be widened into an authentication bypass on a different resource.
|
||||
func pathExcluded(requestPath, excluded string) bool {
|
||||
excluded = strings.TrimRight(excluded, "/")
|
||||
if excluded == "" {
|
||||
// A "/" (root) exclusion only matches the root path, not everything.
|
||||
return requestPath == "" || requestPath == "/"
|
||||
}
|
||||
if requestPath == excluded {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(requestPath, excluded) {
|
||||
return false
|
||||
}
|
||||
switch requestPath[len(excluded)] {
|
||||
case '/', '.':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// buildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
@@ -288,6 +313,63 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDiscoveredEndpoint validates an endpoint URL obtained from the
|
||||
// provider's OIDC/OAuth2 discovery document before the plugin issues any
|
||||
// outbound request to it. A discovery document is attacker-influenced if the
|
||||
// provider is malicious or its TLS is broken, so an unvalidated endpoint is an
|
||||
// SSRF vector (e.g. jwks_uri or introspection_endpoint pointed at the cloud
|
||||
// metadata service 169.254.169.254 or an internal host).
|
||||
//
|
||||
// Empty endpoints are allowed (they are optional). Link-local (which covers the
|
||||
// 169.254.0.0/16 metadata range), multicast and unspecified addresses are
|
||||
// always rejected. Private addresses are rejected unless allowPrivateIPAddresses
|
||||
// is set. Loopback is rejected unless allowLoopback is true — which the caller
|
||||
// sets only when the operator-configured providerURL is itself loopback (local
|
||||
// development, in-cluster sidecars, tests), so production deployments pointed at
|
||||
// a real provider still block loopback SSRF.
|
||||
func (t *TraefikOidc) validateDiscoveredEndpoint(urlStr string, allowLoopback bool) error {
|
||||
if urlStr == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return fmt.Errorf("disallowed URL scheme: %q", u.Scheme)
|
||||
}
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("missing host in URL")
|
||||
}
|
||||
if ip := net.ParseIP(u.Hostname()); ip != nil {
|
||||
switch {
|
||||
case ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified():
|
||||
return fmt.Errorf("endpoint host is a blocked address: %s", ip)
|
||||
case ip.IsLoopback() && !allowLoopback:
|
||||
return fmt.Errorf("endpoint host is a loopback address: %s", ip)
|
||||
case ip.IsPrivate() && !t.allowPrivateIPAddresses:
|
||||
return fmt.Errorf("endpoint host is a private address: %s", ip)
|
||||
}
|
||||
}
|
||||
if strings.Contains(u.Path, "..") {
|
||||
return fmt.Errorf("path traversal detected in URL path")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sameHost reports whether two URLs share the same host:port (case-insensitive).
|
||||
// Used to pin the credential-bearing introspection endpoint to the operator-
|
||||
// configured provider so a poisoned discovery document cannot redirect the
|
||||
// client secret to an attacker-controlled host.
|
||||
func sameHost(a, b string) bool {
|
||||
ua, erra := url.Parse(a)
|
||||
ub, errb := url.Parse(b)
|
||||
if erra != nil || errb != nil || ua.Host == "" || ub.Host == "" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(ua.Host, ub.Host)
|
||||
}
|
||||
|
||||
// validateHost validates a hostname or IP address for security.
|
||||
// It prevents access to localhost, private networks, and known metadata endpoints.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
|
||||
+8
-3
@@ -135,7 +135,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := parts[1]
|
||||
domain := strings.ToLower(parts[1])
|
||||
_, domainAllowed := t.allowedUserDomains[domain]
|
||||
|
||||
if domainAllowed {
|
||||
@@ -236,8 +236,13 @@ func (t *TraefikOidc) Close() error {
|
||||
// Get resource manager for cleanup
|
||||
rm := GetResourceManager()
|
||||
|
||||
// Stop singleton tasks related to this instance
|
||||
_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup
|
||||
// singleton-token-cleanup is a process-global task shared by every plugin
|
||||
// instance. Only stop it when the LAST instance is shutting down;
|
||||
// otherwise one instance's teardown (e.g. a single config reload) would
|
||||
// kill chunked-session/token cleanup for all surviving instances (rank 12).
|
||||
if unregisterLiveInstance() <= 0 {
|
||||
_ = rm.StopBackgroundTask("singleton-token-cleanup") // best effort, last instance only
|
||||
}
|
||||
// Stop metadata refresh task using same hash-based name as startMetadataRefresh
|
||||
if t.providerURL != "" {
|
||||
hash := sha256.Sum256([]byte(t.providerURL))
|
||||
|
||||
Reference in New Issue
Block a user