mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 546ceb949c | |||
| f75b2f20e0 | |||
| cf6ed1da55 | |||
| f821b8829b |
@@ -18,6 +18,6 @@ jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
go-version: "1.25.x"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
|
||||
@@ -19,5 +19,5 @@ jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
go-version: "1.25.x"
|
||||
secrets: inherit
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# traefikoidc — Makefile
|
||||
# Run `make help` for available targets.
|
||||
|
||||
GO ?= go
|
||||
GOPATH := $(shell $(GO) env GOPATH)
|
||||
# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate
|
||||
# tests the real interpreter, not a newer one that may support more. Traefik
|
||||
# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is.
|
||||
YAEGI_VERSION ?= v0.16.1
|
||||
TEST_TIMEOUT ?= 480s
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
|
||||
.PHONY: help
|
||||
help: ## Show this help
|
||||
@grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
.PHONY: build
|
||||
build: ## Compile all packages (native toolchain)
|
||||
$(GO) build ./...
|
||||
|
||||
.PHONY: fmt
|
||||
fmt: ## Format sources with gofmt
|
||||
gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/')
|
||||
|
||||
.PHONY: vet
|
||||
vet: ## Run go vet
|
||||
$(GO) vet ./...
|
||||
|
||||
.PHONY: lint
|
||||
lint: ## Run golangci-lint if available
|
||||
@command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping"
|
||||
|
||||
.PHONY: staticcheck
|
||||
staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.)
|
||||
@command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; }
|
||||
@GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./...
|
||||
|
||||
.PHONY: test
|
||||
test: ## Run the test suite
|
||||
$(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT)
|
||||
|
||||
.PHONY: vendor
|
||||
vendor: ## Refresh and vendor dependencies
|
||||
$(GO) mod tidy && $(GO) mod vendor
|
||||
|
||||
# yaegi-validate interprets the plugin under the yaegi interpreter the same way
|
||||
# Traefik loads it. Native `go build`/`go test` use the standard compiler and do
|
||||
# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection
|
||||
# edge cases). This target does. Importing the package forces yaegi to interpret
|
||||
# every file in it plus its vendored deps; CreateConfig + New exercise the
|
||||
# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you
|
||||
# need exact parity.
|
||||
.PHONY: yaegi-validate
|
||||
yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter
|
||||
@command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); }
|
||||
@echo ">> interpreting plugin under yaegi (as Traefik does)"
|
||||
@DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go
|
||||
|
||||
.PHONY: check
|
||||
check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation
|
||||
@@ -111,7 +111,8 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `logoutURL` | `callbackURL + "/logout"` | RP-initiated logout path. |
|
||||
| `postLogoutRedirectURI` | `/` | Where to send users after logout. |
|
||||
| `scopes` | appended to `openid profile email` | Extra OAuth scopes. Set `overrideScopes: true` to replace defaults. |
|
||||
| `excludedURLs` | none | Prefix-matched paths that bypass auth. |
|
||||
| `extraAuthParams` | none | Map of extra query parameters appended to the authorization request (e.g. `screen_hint: signup`, `login_hint`, `ui_locales`, `prompt`). Plugin-managed params (`client_id`, `state`, `nonce`, `redirect_uri`, `code_challenge`, `scope`, `response_type`, …) cannot be overridden. |
|
||||
| `excludedURLs` | none | Paths that bypass auth, matched at a path-segment or file-extension boundary (e.g. `/public` matches `/public`, `/public/sub` and `/public.json`, but **not** `/publicsecret`). |
|
||||
| `allowedUserDomains` | none | Restrict to email domains. |
|
||||
| `allowedUsers` | none | Restrict to specific addresses (or claim values when `userIdentifierClaim != email`). |
|
||||
| `allowedRolesAndGroups` | none | Require any of these roles/groups from ID-token claims. |
|
||||
@@ -146,6 +147,18 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
|
||||
## Production gotchas
|
||||
|
||||
### Upgrading from an earlier release
|
||||
|
||||
- **Sessions are re-issued once.** Session cookies are now AES-256 encrypted
|
||||
(previously signed only) and their cryptographic lifetime tracks
|
||||
`sessionMaxAge` (previously a fixed 30 days). Existing cookies become invalid
|
||||
on upgrade, so users re-authenticate one time.
|
||||
- **Invalid configuration now fails closed at startup** instead of being
|
||||
silently accepted: a `sessionEncryptionKey` shorter than 32 bytes, a
|
||||
`rateLimit` below 10, a missing `callbackURL`, or a non-HTTPS remote
|
||||
`providerURL` are rejected. Plaintext HTTP is permitted only for loopback
|
||||
hosts (local development).
|
||||
|
||||
### TLS termination at a load balancer
|
||||
|
||||
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
|
||||
@@ -165,6 +178,8 @@ detected" when the same token hits different replicas. Two options:
|
||||
|
||||
For IdP-initiated logout (back/front-channel) in multi-replica setups, Redis is
|
||||
**required** so a logout on one instance invalidates sessions on the others.
|
||||
Front-channel logout requests must include a matching `iss` query parameter;
|
||||
requests that omit it are rejected with `400`.
|
||||
|
||||
### Multiple middleware instances on the same host
|
||||
|
||||
|
||||
+9
-1
@@ -182,6 +182,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
if t.enablePKCE && codeVerifier == "" {
|
||||
t.logger.Error("PKCE is enabled but code verifier is missing from session during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
@@ -263,7 +268,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
// Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored
|
||||
// from the original request target before using it as the post-login
|
||||
// redirect target. normalizeLogoutPath forces a host-relative path.
|
||||
redirectPath = normalizeLogoutPath(incomingPath)
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
|
||||
+101
-10
@@ -149,6 +149,94 @@ func parseBearerJOSEHeader(token string) *bearerError {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerClaimRuneReason reports why a rune is unsafe to inject into a request
|
||||
// header value, or "" if the rune is acceptable. Shared core of the bearer-path
|
||||
// identifier sanitizer and the cookie-path header claim sanitizer: rejects
|
||||
// control chars (CRLF/header injection), Unicode bidi-override runes (RTL
|
||||
// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group
|
||||
// name would inject extra entries into a comma-joined header).
|
||||
func headerClaimRuneReason(r rune) string {
|
||||
if reason := headerInjectionRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
// The , ; = delimiters are only unsafe for values placed into delimited or
|
||||
// list contexts (a comma-joined header, or an identifier downstreams may
|
||||
// split). They are valid in arbitrary single header values, so this stricter
|
||||
// check is used for the cookie-path identifier and the group/role list, NOT
|
||||
// for free-form templated header output (see headerValueReason).
|
||||
if r == ',' || r == ';' || r == '=' {
|
||||
return "delimiter character"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header
|
||||
// value, or "" if acceptable. Rejects control characters (CR/LF header
|
||||
// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs).
|
||||
// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in
|
||||
// free-form header values (e.g. an opaque "Authorization: Bearer <token>").
|
||||
func headerInjectionRuneReason(r rune) string {
|
||||
if unicode.IsControl(r) {
|
||||
return "control character"
|
||||
}
|
||||
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
|
||||
return "bidi-override character"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerValueReason reports why value is unsafe to forward as a free-form HTTP
|
||||
// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0
|
||||
// disables the check) and values containing control or bidi-override runes, but
|
||||
// permits , ; = (valid in header values). Empty is allowed. The reason string
|
||||
// never includes the value, so it is safe to log.
|
||||
func headerValueReason(value string, maxLen int) string {
|
||||
if maxLen > 0 && len(value) > maxLen {
|
||||
return "exceeds max length"
|
||||
}
|
||||
for _, r := range value {
|
||||
if reason := headerInjectionRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// headerClaimValueReason reports why value is unsafe to inject into a
|
||||
// downstream request header, or "" if it is acceptable. It rejects empty
|
||||
// values, values exceeding maxLen (maxLen<=0 disables the length check), and
|
||||
// values containing any rune rejected by headerClaimRuneReason. The reason
|
||||
// string is safe to log (it never includes the value itself).
|
||||
func headerClaimValueReason(value string, maxLen int) string {
|
||||
if value == "" {
|
||||
return "empty value"
|
||||
}
|
||||
if maxLen > 0 && len(value) > maxLen {
|
||||
return "exceeds max length"
|
||||
}
|
||||
for _, r := range value {
|
||||
if reason := headerClaimRuneReason(r); reason != "" {
|
||||
return reason
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// sanitizeHeaderClaimValue validates a claim-derived value before it is
|
||||
// injected into a downstream request header. It trims surrounding whitespace
|
||||
// and fails closed (ok=false) on empty values, values exceeding maxLen
|
||||
// (maxLen<=0 disables the length check), or values containing any rune rejected
|
||||
// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the
|
||||
// bearer path — does not otherwise sanitize the principal identifier or the
|
||||
// group/role strings joined into X-User-Groups / X-User-Roles.
|
||||
func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if headerClaimValueReason(value, maxLen) != "" {
|
||||
return "", false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
// sanitizeBearerIdentifier validates and trims a principal identifier before
|
||||
// it is injected into request headers. Layered defense: net/http will reject
|
||||
// CRLF on the wire too, but rejecting early gives clearer error logs and
|
||||
@@ -163,15 +251,8 @@ func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
|
||||
}
|
||||
for _, r := range identifier {
|
||||
if unicode.IsControl(r) {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains control character")
|
||||
}
|
||||
// Unicode bidi-override range (RTL spoofing of admin UI / SIEM).
|
||||
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains bidi-override character")
|
||||
}
|
||||
if r == ',' || r == ';' || r == '=' {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains delimiter character")
|
||||
if reason := headerClaimRuneReason(r); reason != "" {
|
||||
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason)
|
||||
}
|
||||
}
|
||||
return identifier, nil
|
||||
@@ -342,7 +423,17 @@ func (b *bearerFailureTracker) recordSuccess(ip string) {
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
delete(b.entries, ip)
|
||||
e, ok := b.entries[ip]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Preserve an active penalty so a single success cannot wipe an in-effect
|
||||
// lockout; only reset the counter when no penalty is active or it has expired.
|
||||
now := time.Now()
|
||||
if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) {
|
||||
e.count = 0
|
||||
e.firstFailureAt = now
|
||||
}
|
||||
}
|
||||
|
||||
// clientIPForBearer returns the source IP used to key the failure tracker.
|
||||
|
||||
+45
-27
@@ -67,31 +67,31 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
|
||||
t.Helper()
|
||||
sm := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("error"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestStarted: 1,
|
||||
next: next,
|
||||
logger: NewLogger("error"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://issuer.example.com",
|
||||
audience: "https://api.example.com",
|
||||
clientID: "https://api.example.com",
|
||||
tokenCache: NewTokenCache(),
|
||||
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
|
||||
ctx: context.Background(),
|
||||
enableBearerAuth: true,
|
||||
stripAuthorizationHeader: true,
|
||||
bearerEmitWWWAuthenticate: true,
|
||||
bearerOverridesCookie: false,
|
||||
bearerIdentifierClaim: "sub",
|
||||
maxIdentifierLength: 256,
|
||||
maxTokenAge: 24 * time.Hour,
|
||||
bearerFailureThreshold: 20,
|
||||
bearerFailureWindow: 60 * time.Second,
|
||||
bearerFailurePenalty: 60 * time.Second,
|
||||
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
|
||||
issuerURL: "https://issuer.example.com",
|
||||
audience: "https://api.example.com",
|
||||
clientID: "https://api.example.com",
|
||||
tokenCache: NewTokenCache(),
|
||||
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
|
||||
ctx: context.Background(),
|
||||
enableBearerAuth: true,
|
||||
stripAuthorizationHeader: true,
|
||||
bearerEmitWWWAuthenticate: true,
|
||||
bearerOverridesCookie: false,
|
||||
bearerIdentifierClaim: "sub",
|
||||
maxIdentifierLength: 256,
|
||||
maxTokenAge: 24 * time.Hour,
|
||||
bearerFailureThreshold: 20,
|
||||
bearerFailureWindow: 60 * time.Second,
|
||||
bearerFailurePenalty: 60 * time.Second,
|
||||
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
|
||||
}
|
||||
oidc.extractClaimsFunc = extractClaims
|
||||
close(oidc.initComplete)
|
||||
@@ -303,15 +303,33 @@ func TestBearerFailureTracker(t *testing.T) {
|
||||
if b, retry := tr.blocked(ip); !b || retry <= 0 {
|
||||
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
|
||||
}
|
||||
// Success clears the counter.
|
||||
// A success while a penalty is active must NOT wipe the in-effect lockout
|
||||
// (otherwise a single success could clear an attacker's penalty).
|
||||
tr.recordSuccess(ip)
|
||||
if b, _ := tr.blocked(ip); b {
|
||||
t.Fatalf("expected unblocked after success")
|
||||
if b, _ := tr.blocked(ip); !b {
|
||||
t.Fatalf("expected still blocked after success while penalty active")
|
||||
}
|
||||
// Other IPs are unaffected.
|
||||
if b, _ := tr.blocked("10.0.0.2"); b {
|
||||
t.Fatalf("unrelated IP should not be blocked")
|
||||
}
|
||||
|
||||
// With an expired penalty, a success resets the counter so a subsequent
|
||||
// sub-threshold failure does not immediately re-block.
|
||||
tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond)
|
||||
const ip2 = "10.0.0.3"
|
||||
for i := 0; i < 3; i++ {
|
||||
tr2.recordFailure(ip2)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond) // let the short penalty expire
|
||||
if b, _ := tr2.blocked(ip2); b {
|
||||
t.Fatalf("expected unblocked after penalty expiry")
|
||||
}
|
||||
tr2.recordSuccess(ip2) // resets count since penalty has passed
|
||||
tr2.recordFailure(ip2) // single failure, well below threshold
|
||||
if b, _ := tr2.blocked(ip2); b {
|
||||
t.Fatalf("expected unblocked: counter should have reset after success")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
|
||||
+23
-2
@@ -16,8 +16,9 @@ type CacheManager struct {
|
||||
}
|
||||
|
||||
var (
|
||||
globalCacheManagerInstance *CacheManager
|
||||
cacheManagerInitOnce sync.Once
|
||||
globalCacheManagerInstance *CacheManager
|
||||
cacheManagerInitOnce sync.Once
|
||||
cacheManagerActiveFingerprint string
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance.
|
||||
@@ -29,7 +30,9 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
fp := redisFingerprint(config)
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
cacheManagerActiveFingerprint = fp
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
@@ -55,9 +58,27 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
// Warn loudly if a later instance asks for a DIFFERENT explicit Redis
|
||||
// backend than the one that won initialization: the cache manager is a
|
||||
// process-global singleton shared across plugin instances (yaegi), so this
|
||||
// instance's divergent configuration is silently ignored, which would
|
||||
// otherwise collapse cache/state isolation between routes (rank 9).
|
||||
if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint {
|
||||
NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp)
|
||||
}
|
||||
return globalCacheManagerInstance
|
||||
}
|
||||
|
||||
// redisFingerprint returns a stable identifier for an explicitly-enabled Redis
|
||||
// backend (address + key prefix), or "" when Redis is not explicitly enabled.
|
||||
// Used to detect divergent cache configurations across plugin instances.
|
||||
func redisFingerprint(config *Config) string {
|
||||
if config == nil || config.Redis == nil || !config.Redis.Enabled {
|
||||
return ""
|
||||
}
|
||||
return config.Redis.Address + "|" + config.Redis.KeyPrefix
|
||||
}
|
||||
|
||||
// GetSharedTokenBlacklist returns the shared token blacklist cache
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build ignore
|
||||
|
||||
// Command yaegicheck verifies that the traefikoidc plugin can be imported and
|
||||
// instantiated by the yaegi interpreter — the same way Traefik loads a plugin.
|
||||
//
|
||||
// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi
|
||||
// to interpret every source file in the package (and its vendored
|
||||
// dependencies), so any construct yaegi cannot handle (unsupported stdlib
|
||||
// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load
|
||||
// time. CreateConfig + New additionally exercise the instantiation path
|
||||
// (session manager, cookie codec, caches, key derivation) under the interpreter.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
oidc "github.com/lukaszraczylo/traefikoidc"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := oidc.CreateConfig()
|
||||
cfg.ProviderURL = "https://accounts.google.com"
|
||||
cfg.ClientID = "yaegi-check-client"
|
||||
cfg.ClientSecret = "yaegi-check-secret"
|
||||
cfg.CallbackURL = "/oauth2/callback"
|
||||
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef"
|
||||
cfg.RateLimit = 100
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
h, err := oidc.New(context.Background(), next, cfg, "yaegi-check")
|
||||
if err != nil {
|
||||
fmt.Println("FAIL: New returned an error under yaegi:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if h == nil {
|
||||
fmt.Println("FAIL: New returned a nil handler under yaegi")
|
||||
os.Exit(1)
|
||||
}
|
||||
if closer, ok := h.(interface{ Close() error }); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi")
|
||||
}
|
||||
@@ -278,82 +278,6 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SECURITY MONITORING TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
config := SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 5,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 30,
|
||||
RapidFailureThreshold: 3,
|
||||
CleanupIntervalMinutes: 60,
|
||||
RetentionHours: 24,
|
||||
EnablePatternDetection: true,
|
||||
EnableDetailedLogging: false,
|
||||
LogSuspiciousOnly: false,
|
||||
}
|
||||
|
||||
sm := NewSecurityMonitor(config, logger)
|
||||
if sm == nil {
|
||||
t.Fatal("Expected non-nil SecurityMonitor")
|
||||
}
|
||||
|
||||
// Start cleanup routine first (lowercase method)
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop cleanup routine (public method)
|
||||
sm.StopCleanupRoutine()
|
||||
|
||||
// Stop again should be safe
|
||||
sm.StopCleanupRoutine()
|
||||
}
|
||||
|
||||
func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
config := SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 5,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 30,
|
||||
RapidFailureThreshold: 3,
|
||||
CleanupIntervalMinutes: 60,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
|
||||
sm := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Create handler
|
||||
handler := &LoggingSecurityEventHandler{logger: logger}
|
||||
|
||||
// Register handler using AddEventHandler
|
||||
sm.AddEventHandler(handler)
|
||||
|
||||
// Record a failure to trigger events
|
||||
sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil)
|
||||
}
|
||||
|
||||
func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
handler := &LoggingSecurityEventHandler{logger: logger}
|
||||
|
||||
// Severity is a string in this implementation
|
||||
events := []SecurityEvent{
|
||||
{Type: "test", Severity: "low", Message: "low severity"},
|
||||
{Type: "test", Severity: "medium", Message: "medium severity"},
|
||||
{Type: "test", Severity: "high", Message: "high severity"},
|
||||
{Type: "test", Severity: "critical", Message: "critical severity"},
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SESSION MANAGER TESTS
|
||||
// =============================================================================
|
||||
|
||||
@@ -178,7 +178,7 @@ clientSecret: your-client-secret
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
|
||||
@@ -370,21 +370,6 @@ func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, res
|
||||
return r.saveCredentials(resp)
|
||||
}
|
||||
|
||||
// deleteCredentialsFromStore removes credentials from the configured storage backend
|
||||
// Falls back to legacy file-based deletion if no store is configured
|
||||
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Delete(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based deletion
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
@@ -423,187 +408,3 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse,
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateClientRegistration updates an existing client registration using RFC 7592
|
||||
// This requires the registration_client_uri and registration_access_token from the original registration
|
||||
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to update")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Build update request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build update request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create update request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read update response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update response: %w", err)
|
||||
}
|
||||
|
||||
// Update cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// ReadClientRegistration reads the current client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to read")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create read request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse read response: %w", err)
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// DeleteClientRegistration deletes the client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return fmt.Errorf("no existing registration to delete")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle error responses (204 No Content is success)
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials from storage if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.deleteCredentialsFromStore(ctx); err != nil {
|
||||
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Successfully deleted client registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -735,258 +735,6 @@ func TestDCRConfigDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateClientRegistration tests the RFC 7592 client update functionality
|
||||
func TestUpdateClientRegistration(t *testing.T) {
|
||||
updateCalled := false
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPut {
|
||||
updateCalled = true
|
||||
|
||||
// Verify authorization header
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
t.Error("Missing Authorization header for update")
|
||||
}
|
||||
|
||||
resp := ClientRegistrationResponse{
|
||||
ClientID: "updated-client-id",
|
||||
ClientSecret: "updated-client-secret",
|
||||
RegistrationAccessToken: "new-access-token",
|
||||
RegistrationClientURI: r.URL.String(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
ClientMetadata: &ClientRegistrationMetadata{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
},
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "original-client-id",
|
||||
ClientSecret: "original-client-secret",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Perform update
|
||||
ctx := context.Background()
|
||||
resp, err := registrar.UpdateClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
if !updateCalled {
|
||||
t.Error("Update endpoint was not called")
|
||||
}
|
||||
|
||||
if resp.ClientID != "updated-client-id" {
|
||||
t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality
|
||||
func TestDeleteClientRegistration(t *testing.T) {
|
||||
deleteCalled := false
|
||||
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodDelete {
|
||||
deleteCalled = true
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
credentialsFile := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
// Create a credentials file to test deletion
|
||||
os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600)
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
CredentialsFile: credentialsFile,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Perform delete
|
||||
ctx := context.Background()
|
||||
err := registrar.DeleteClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if !deleteCalled {
|
||||
t.Error("Delete endpoint was not called")
|
||||
}
|
||||
|
||||
// Verify cache is cleared
|
||||
if registrar.GetCachedResponse() != nil {
|
||||
t.Error("Cached response should be cleared after deletion")
|
||||
}
|
||||
|
||||
// Verify credentials file is deleted
|
||||
if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) {
|
||||
t.Error("Credentials file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadClientRegistration tests the RFC 7592 client read functionality
|
||||
func TestReadClientRegistration(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet {
|
||||
resp := ClientRegistrationResponse{
|
||||
ClientID: "read-client-id",
|
||||
ClientSecret: "read-client-secret",
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ResponseTypes: []string{"code"},
|
||||
GrantTypes: []string{"authorization_code"},
|
||||
ApplicationType: "web",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
server.Client(),
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
// Set up cached response with management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "original-client-id",
|
||||
RegistrationAccessToken: "access-token",
|
||||
RegistrationClientURI: server.URL + "/register/client123",
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
// Read registration
|
||||
ctx := context.Background()
|
||||
resp, err := registrar.ReadClientRegistration(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.ClientID != "read-client-id" {
|
||||
t.Errorf("Read ClientID mismatch: got %s", resp.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationsWithoutCachedResponse tests error handling when no cached response exists
|
||||
func TestOperationsWithoutCachedResponse(t *testing.T) {
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
&http.Client{},
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Update without cached response
|
||||
_, err := registrar.UpdateClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Update should fail without cached response: %v", err)
|
||||
}
|
||||
|
||||
// Test Read without cached response
|
||||
_, err = registrar.ReadClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Read should fail without cached response: %v", err)
|
||||
}
|
||||
|
||||
// Test Delete without cached response
|
||||
err = registrar.DeleteClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "no existing registration") {
|
||||
t.Errorf("Delete should fail without cached response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationsWithoutManagementCredentials tests error handling without management URIs
|
||||
func TestOperationsWithoutManagementCredentials(t *testing.T) {
|
||||
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
|
||||
|
||||
registrar := NewDynamicClientRegistrar(
|
||||
&http.Client{},
|
||||
NewLogger("DEBUG"),
|
||||
dcrConfig,
|
||||
"https://example.com",
|
||||
)
|
||||
|
||||
// Set up cached response WITHOUT management credentials
|
||||
registrar.mu.Lock()
|
||||
registrar.registrationResponse = &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
// Missing RegistrationAccessToken and RegistrationClientURI
|
||||
}
|
||||
registrar.mu.Unlock()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Update without management credentials
|
||||
_, err := registrar.UpdateClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Update should fail without management credentials: %v", err)
|
||||
}
|
||||
|
||||
// Test Read without management credentials
|
||||
_, err = registrar.ReadClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Read should fail without management credentials: %v", err)
|
||||
}
|
||||
|
||||
// Test Delete without management credentials
|
||||
err = registrar.DeleteClientRegistration(ctx)
|
||||
if err == nil || !stringContains(err.Error(), "registration management not supported") {
|
||||
t.Errorf("Delete should fail without management credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stringContains is a helper function to check if a string contains a substring
|
||||
func stringContains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr))
|
||||
|
||||
+6
-5
@@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
errStr := strings.ToLower(err.Error())
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
if contains(errStr, retryableErr) {
|
||||
if contains(errStr, strings.ToLower(retryableErr)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if netErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
errStr := netErr.Error()
|
||||
errStr := strings.ToLower(netErr.Error())
|
||||
temporaryPatterns := []string{
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
@@ -859,8 +859,9 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
|
||||
|
||||
// isServiceDegraded checks if a service is currently degraded
|
||||
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
|
||||
gd.mutex.RLock()
|
||||
defer gd.mutex.RUnlock()
|
||||
// Uses a write lock because the recovery-timeout branch deletes from the map.
|
||||
gd.mutex.Lock()
|
||||
defer gd.mutex.Unlock()
|
||||
|
||||
degradedTime, exists := gd.degradedServices[serviceName]
|
||||
if !exists {
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.24.0
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
|
||||
@@ -16,6 +16,8 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3 h1:xoDtBqeZGmXj7IteiE1M5WMuzeoqag58qEleI0Cf2Ms=
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3/go.mod h1:+Cn78qZo8rc3T9eZt0v3oICYRdd75wORtSidc8lNjDQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
|
||||
+10
-1
@@ -392,10 +392,19 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
// localRedirect is used when there is no provider end-session endpoint and
|
||||
// the plugin redirects the browser itself. It must never be an absolute URL
|
||||
// derived from the request host (X-Forwarded-Host is client-controllable and
|
||||
// would be an open redirect); use a host-relative path, or the operator's
|
||||
// own configured absolute URL, instead.
|
||||
localRedirect := "/"
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
localRedirect = normalizeLogoutPath(postLogoutRedirectURI)
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
} else {
|
||||
localRedirect = postLogoutRedirectURI
|
||||
}
|
||||
|
||||
// Read endSessionURL with RLock
|
||||
@@ -414,7 +423,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
http.Redirect(rw, req, localRedirect, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint.
|
||||
|
||||
+30
-6
@@ -26,6 +26,10 @@ type sharedTransport struct {
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
refCount int
|
||||
// tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify)
|
||||
// this transport was built with, so the at-limit fallback only reuses a
|
||||
// transport whose TLS configuration matches the caller's.
|
||||
tlsKey string
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -53,19 +57,26 @@ func GetGlobalTransportPool() *SharedTransportPool {
|
||||
|
||||
// GetOrCreateTransport gets or creates a shared transport with the given config
|
||||
func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport {
|
||||
// SECURITY FIX: Check client limit before creating new transport
|
||||
// SECURITY FIX: Check client limit before creating new transport.
|
||||
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
|
||||
// Return existing transport if limit reached
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
// At the client limit: only reuse a transport that was built for the
|
||||
// SAME config (same TLS trust store). refCount is mutated under the
|
||||
// write lock to avoid a data race, and a transport created for a
|
||||
// different configuration is never handed back — doing so could apply
|
||||
// the wrong (possibly verification-disabled) TLS settings to a request.
|
||||
want := tlsConfigKey(config)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
for _, shared := range p.transports {
|
||||
if shared != nil && shared.transport != nil {
|
||||
if shared != nil && shared.transport != nil && shared.tlsKey == want {
|
||||
shared.refCount++
|
||||
shared.lastUsed = time.Now()
|
||||
return shared.transport
|
||||
}
|
||||
}
|
||||
// If no transport available, return nil (caller should handle)
|
||||
// No TLS-compatible transport available; return nil so the caller falls
|
||||
// back to a default, certificate-verifying transport rather than one
|
||||
// with a different (possibly verification-disabled) trust store.
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -125,6 +136,7 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
|
||||
transport: transport,
|
||||
refCount: 1,
|
||||
lastUsed: time.Now(),
|
||||
tlsKey: tlsConfigKey(config),
|
||||
}
|
||||
|
||||
return transport
|
||||
@@ -224,6 +236,18 @@ func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
|
||||
)
|
||||
}
|
||||
|
||||
// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool
|
||||
// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are
|
||||
// safe to serve from the same transport even if other (non-TLS) parameters such
|
||||
// as connection limits differ; configs with different TLS settings are not.
|
||||
func tlsConfigKey(config HTTPClientConfig) string {
|
||||
skip := "0"
|
||||
if config.InsecureSkipVerify {
|
||||
skip = "1"
|
||||
}
|
||||
return fmt.Sprintf("%p|%s", config.RootCAs, skip)
|
||||
}
|
||||
|
||||
// Cleanup closes all transports and stops the cleanup goroutine
|
||||
func (p *SharedTransportPool) Cleanup() {
|
||||
p.mu.Lock()
|
||||
|
||||
@@ -842,10 +842,18 @@ func TestWorkerPool_TaskPanic(t *testing.T) {
|
||||
t.Error("Timeout waiting for tasks")
|
||||
}
|
||||
|
||||
// Pool should still be functional
|
||||
metrics := pool.GetMetrics()
|
||||
if metrics["tasksFailed"].(int64) < 1 {
|
||||
t.Error("Expected at least one failed task")
|
||||
// tasksFailed is incremented in the worker's deferred recover(), which runs
|
||||
// AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can
|
||||
// therefore return before the failure is recorded — reading the counter
|
||||
// immediately is a race that flakes on slow/contended CI runners. Poll until
|
||||
// the failure lands (or time out).
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for pool.GetMetrics()["tasksFailed"].(int64) < 1 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Error("Expected at least one failed task")
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+23
-1
@@ -155,12 +155,34 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string {
|
||||
// It checks X-Forwarded-Host header first (for proxy scenarios),
|
||||
// then falls back to req.Host.
|
||||
func DetermineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// sanitizeForwardedHost returns a single, well-formed host from a (possibly
|
||||
// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes
|
||||
// only the first value and rejects whitespace and control characters, so a
|
||||
// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison
|
||||
// the redirect URLs built from the result.
|
||||
func sanitizeForwardedHost(v string) string {
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if i := strings.IndexByte(v, ','); i >= 0 {
|
||||
v = v[:i]
|
||||
}
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 {
|
||||
return ""
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// BuildFullURL constructs a URL from scheme, host, and path components.
|
||||
// It handles absolute URLs (returning them as-is) and ensures paths have leading slashes.
|
||||
func BuildFullURL(scheme, host, path string) string {
|
||||
|
||||
@@ -200,6 +200,22 @@ func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
|
||||
if k.Kid == "" {
|
||||
continue
|
||||
}
|
||||
// Skip keys that are not intended for signature verification.
|
||||
if k.Use != "" && k.Use != "sig" {
|
||||
continue
|
||||
}
|
||||
if len(k.KeyOps) > 0 {
|
||||
hasVerify := false
|
||||
for _, op := range k.KeyOps {
|
||||
if op == "verify" {
|
||||
hasVerify = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasVerify {
|
||||
continue
|
||||
}
|
||||
}
|
||||
var pub crypto.PublicKey
|
||||
var err error
|
||||
switch k.Kty {
|
||||
@@ -242,11 +258,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading JWKS response: %w", err)
|
||||
}
|
||||
|
||||
@@ -134,8 +134,11 @@ func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if iss != "" && iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
|
||||
// Require a matching issuer. An empty iss must be rejected too: accepting a
|
||||
// missing issuer would let an unauthenticated attacker force-logout any
|
||||
// session whose sid is known by simply omitting iss.
|
||||
if iss == "" || iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer)
|
||||
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
+32
-25
@@ -125,10 +125,14 @@ func TestFrontchannelLogoutBasic(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Valid front-channel logout without issuer",
|
||||
// Front-channel logout MUST carry a matching issuer. A request
|
||||
// omitting iss is rejected so an unauthenticated attacker cannot
|
||||
// force-logout a session whose sid is known by simply leaving iss
|
||||
// out (audit rank 30).
|
||||
name: "Missing issuer is rejected",
|
||||
method: http.MethodGet,
|
||||
queryParams: map[string]string{"sid": "session456"},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -407,17 +411,17 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableBackchannelLogout: true,
|
||||
backchannelLogoutPath: "/backchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestStarted: 1,
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableBackchannelLogout: true,
|
||||
backchannelLogoutPath: "/backchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
@@ -449,22 +453,23 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableFrontchannelLogout: true,
|
||||
frontchannelLogoutPath: "/frontchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestStarted: 1,
|
||||
next: nextHandler,
|
||||
logger: NewLogger("debug"),
|
||||
enableFrontchannelLogout: true,
|
||||
frontchannelLogoutPath: "/frontchannel-logout",
|
||||
sessionInvalidationCache: mockCache,
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Request to front-channel logout path with valid sid should succeed
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil)
|
||||
// Request to front-channel logout path with valid sid + matching issuer
|
||||
// should succeed. The issuer is now required (audit rank 30), so supply it.
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session&iss=https://provider.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
@@ -1432,7 +1437,9 @@ func TestFrontchannelLogoutCacheControl(t *testing.T) {
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil)
|
||||
// Issuer is now required (audit rank 30); supply a matching one so the
|
||||
// successful-logout cache headers can be asserted.
|
||||
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123&iss=https://provider.example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.handleFrontchannelLogout(rw, req)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
telemetry "github.com/lukaszraczylo/oss-telemetry"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@@ -23,6 +25,11 @@ const (
|
||||
ConstSessionTimeout = 86400
|
||||
)
|
||||
|
||||
// telemetryStartupOnce keeps the anonymous "plugin loaded" ping to one per
|
||||
// process. Traefik calls New once per route that uses the plugin; oss-telemetry
|
||||
// does not deduplicate client-side (the server does), so the gate stays here.
|
||||
var telemetryStartupOnce sync.Once
|
||||
|
||||
// isTestMode detects if the code is running in a test environment.
|
||||
func isTestMode() bool {
|
||||
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
|
||||
@@ -89,7 +96,13 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
// - The configured TraefikOidc handler ready to process requests.
|
||||
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
sendTelemetry(pluginVersion)
|
||||
telemetryStartupOnce.Do(func() {
|
||||
// Only stamped release builds phone home; dev/local/test builds keep the
|
||||
// devPluginVersion sentinel (see version.go) and stay silent.
|
||||
if traefikoidcPluginVersion != devPluginVersion {
|
||||
telemetry.Send("traefikoidc", traefikoidcPluginVersion)
|
||||
}
|
||||
})
|
||||
return NewWithContext(ctx, config, next, name)
|
||||
}
|
||||
|
||||
@@ -100,18 +113,18 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
config = CreateConfig()
|
||||
}
|
||||
|
||||
if config.SessionEncryptionKey == "" {
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
}
|
||||
|
||||
logger := NewLogger(config.LogLevel)
|
||||
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
if runtime.Compiler == "yaegi" {
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
||||
} else {
|
||||
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||
}
|
||||
|
||||
// Fail closed on invalid configuration. Validate() enforces the security
|
||||
// constraints (required fields, HTTPS-only URLs, key length, excludedURLs
|
||||
// safety, rate-limit floor, audience format, ...) that were previously
|
||||
// unenforced because this constructor never called it. Crucially it rejects
|
||||
// an empty or too-short SessionEncryptionKey instead of silently
|
||||
// substituting a public hardcoded key, which would let an attacker forge
|
||||
// any session. Traefik's yaegi plugin analyzer supplies a valid key via
|
||||
// .traefik.yml testData, so it passes; only misconfigured deployments fail.
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
// Setup HTTP client
|
||||
caPool, err := config.loadCACertPool()
|
||||
@@ -202,6 +215,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}(),
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
extraAuthParams: config.ExtraAuthParams,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
strictAudienceValidation: config.StrictAudienceValidation,
|
||||
allowOpaqueTokens: config.AllowOpaqueTokens,
|
||||
@@ -222,7 +236,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
httpClient: httpClient,
|
||||
tokenHTTPClient: tokenHTTPClient,
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains),
|
||||
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
@@ -333,7 +347,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
// Convert sessionMaxAge from seconds to duration (0 will use default 24 hours)
|
||||
sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults
|
||||
sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
return nil, fmt.Errorf("failed to create session manager: %w", err)
|
||||
}
|
||||
t.sessionManager = sessionManager
|
||||
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
|
||||
|
||||
// Initialize token resilience manager with default configuration
|
||||
@@ -424,6 +443,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
|
||||
// Add reference for this instance
|
||||
rm.AddReference(name)
|
||||
registerLiveInstance()
|
||||
|
||||
// Initialize metadata in a goroutine with proper tracking
|
||||
if t.goroutineWG != nil {
|
||||
@@ -501,13 +521,58 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
// Parameters:
|
||||
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
// SSRF defense (audit ranks 3 & 4): a discovery document is attacker-
|
||||
// influenced when the provider or its TLS is compromised. Reject any
|
||||
// discovered endpoint pointed at a blocked address before the plugin issues
|
||||
// outbound requests to it, so it can never be used to reach the cloud
|
||||
// metadata service or an internal host.
|
||||
allowLoopback := false
|
||||
if pu, err := url.Parse(t.providerURL); err == nil {
|
||||
allowLoopback = isLoopbackHost(pu.Hostname())
|
||||
}
|
||||
sanitize := func(name, raw string) string {
|
||||
if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil {
|
||||
t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err)
|
||||
return ""
|
||||
}
|
||||
return raw
|
||||
}
|
||||
metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL)
|
||||
metadata.AuthURL = sanitize("authorization", metadata.AuthURL)
|
||||
metadata.TokenURL = sanitize("token", metadata.TokenURL)
|
||||
metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL)
|
||||
metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL)
|
||||
metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL)
|
||||
metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL)
|
||||
// The introspection request authenticates with the client secret via HTTP
|
||||
// Basic, so the endpoint must live on the same host as the operator-
|
||||
// configured provider; otherwise a poisoned discovery document could
|
||||
// exfiltrate the client secret to an attacker-controlled host.
|
||||
if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) {
|
||||
t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL)
|
||||
metadata.IntrospectionURL = ""
|
||||
}
|
||||
|
||||
// Pin the discovered issuer to the operator-configured provider host. The
|
||||
// issuer is the trust anchor for JWT issuer validation, so a poisoned
|
||||
// discovery document advertising an attacker-chosen issuer must never be
|
||||
// stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the
|
||||
// issuer on the same host as the configured providerURL. On mismatch, leave
|
||||
// issuerURL empty/unchanged so downstream issuer validation fails closed
|
||||
// rather than trusting the attacker-chosen value.
|
||||
discoveredIssuer := metadata.Issuer
|
||||
if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) {
|
||||
t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer)
|
||||
discoveredIssuer = ""
|
||||
}
|
||||
|
||||
t.metadataMu.Lock()
|
||||
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.issuerURL = discoveredIssuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
|
||||
@@ -520,7 +585,7 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
|
||||
// this directly instead of acquiring metadataMu.RLock per request.
|
||||
t.metadataSnapshot.Store(&MetadataSnapshot{
|
||||
IssuerURL: metadata.Issuer,
|
||||
IssuerURL: discoveredIssuer,
|
||||
JWKSURL: metadata.JWKSURL,
|
||||
TokenURL: metadata.TokenURL,
|
||||
AuthURL: metadata.AuthURL,
|
||||
|
||||
@@ -194,6 +194,7 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(ctx, nil, config, "test")
|
||||
if err != nil {
|
||||
@@ -322,6 +323,7 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(ctx, nil, config, "test")
|
||||
if err != nil {
|
||||
|
||||
+56
-38
@@ -26,38 +26,47 @@ func TestInitializeMetadata(t *testing.T) {
|
||||
name: "successful metadata initialization",
|
||||
providerURL: "",
|
||||
setupMock: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer must share the host with providerURL (the httptest
|
||||
// server), otherwise the discovery doc is rejected as poisoned
|
||||
// (audit ranks 21/22). Real providers keep issuer + endpoints on
|
||||
// the same host, so derive them all from the server URL.
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://provider.example.com",
|
||||
AuthURL: "https://provider.example.com/auth",
|
||||
TokenURL: "https://provider.example.com/token",
|
||||
JWKSURL: "https://provider.example.com/jwks",
|
||||
RevokeURL: "https://provider.example.com/revoke",
|
||||
EndSessionURL: "https://provider.example.com/logout",
|
||||
Issuer: srv.URL,
|
||||
AuthURL: srv.URL + "/auth",
|
||||
TokenURL: srv.URL + "/token",
|
||||
JWKSURL: srv.URL + "/jwks",
|
||||
RevokeURL: srv.URL + "/revoke",
|
||||
EndSessionURL: srv.URL + "/logout",
|
||||
})
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
},
|
||||
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
|
||||
if oidc.authURL != "https://provider.example.com/auth" {
|
||||
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
|
||||
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
|
||||
}
|
||||
if oidc.tokenURL != "https://provider.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
|
||||
}
|
||||
if oidc.jwksURL != "https://provider.example.com/jwks" {
|
||||
if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") {
|
||||
t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL)
|
||||
}
|
||||
if oidc.revocationURL != "https://provider.example.com/revoke" {
|
||||
if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") {
|
||||
t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL)
|
||||
}
|
||||
if oidc.endSessionURL != "https://provider.example.com/logout" {
|
||||
if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") {
|
||||
t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL)
|
||||
}
|
||||
if oidc.issuerURL == "" {
|
||||
t.Errorf("expected issuerURL to be pinned to provider host, got empty")
|
||||
}
|
||||
},
|
||||
wantPanic: false,
|
||||
},
|
||||
@@ -116,24 +125,27 @@ func TestInitializeMetadata(t *testing.T) {
|
||||
name: "partial metadata response",
|
||||
providerURL: "",
|
||||
setupMock: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Only return some fields
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"issuer": "https://partial.example.com",
|
||||
"authorization_endpoint": "https://partial.example.com/auth",
|
||||
"token_endpoint": "https://partial.example.com/token",
|
||||
"issuer": srv.URL,
|
||||
"authorization_endpoint": srv.URL + "/auth",
|
||||
"token_endpoint": srv.URL + "/token",
|
||||
// Missing jwks_uri, revocation_endpoint, end_session_endpoint
|
||||
})
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
},
|
||||
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
|
||||
if oidc.authURL != "https://partial.example.com/auth" {
|
||||
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
|
||||
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
|
||||
}
|
||||
if oidc.tokenURL != "https://partial.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
|
||||
}
|
||||
// JWKS URL and others may be empty
|
||||
@@ -198,20 +210,22 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
|
||||
requestCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requestCount++
|
||||
mu.Unlock()
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://concurrent.example.com",
|
||||
AuthURL: "https://concurrent.example.com/auth",
|
||||
TokenURL: "https://concurrent.example.com/token",
|
||||
JWKSURL: "https://concurrent.example.com/jwks",
|
||||
RevokeURL: "https://concurrent.example.com/revoke",
|
||||
EndSessionURL: "https://concurrent.example.com/logout",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
RevokeURL: server.URL + "/revoke",
|
||||
EndSessionURL: server.URL + "/logout",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -250,7 +264,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
|
||||
// Verify initialization
|
||||
if oidc.tokenURL != "https://concurrent.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("expected tokenURL to be set")
|
||||
}
|
||||
}()
|
||||
@@ -342,17 +356,19 @@ func TestProviderDetection(t *testing.T) {
|
||||
// TestInitializationWaiting tests waiting for initialization to complete
|
||||
func TestInitializationWaiting(t *testing.T) {
|
||||
t.Run("wait for initialization completion", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Delay response to simulate slow initialization
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://slow.example.com",
|
||||
AuthURL: "https://slow.example.com/auth",
|
||||
TokenURL: "https://slow.example.com/token",
|
||||
JWKSURL: "https://slow.example.com/jwks",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -389,7 +405,7 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
select {
|
||||
case <-oidc.initComplete:
|
||||
// Success
|
||||
if oidc.tokenURL != "https://slow.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Error("expected tokenURL to be set after initialization")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
@@ -398,17 +414,19 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("multiple waiters for initialization", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Delay to ensure multiple waiters
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Issuer host must match providerURL (audit ranks 21/22).
|
||||
json.NewEncoder(w).Encode(ProviderMetadata{
|
||||
Issuer: "https://multi.example.com",
|
||||
AuthURL: "https://multi.example.com/auth",
|
||||
TokenURL: "https://multi.example.com/token",
|
||||
JWKSURL: "https://multi.example.com/jwks",
|
||||
Issuer: server.URL,
|
||||
AuthURL: server.URL + "/auth",
|
||||
TokenURL: server.URL + "/token",
|
||||
JWKSURL: server.URL + "/jwks",
|
||||
})
|
||||
}
|
||||
}))
|
||||
@@ -453,7 +471,7 @@ func TestInitializationWaiting(t *testing.T) {
|
||||
select {
|
||||
case <-oidc.initComplete:
|
||||
// All waiters should see the same initialized state
|
||||
if oidc.tokenURL != "https://multi.example.com/token" {
|
||||
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
|
||||
t.Errorf("waiter %d: expected tokenURL to be set", id)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
|
||||
+68
-44
@@ -1875,14 +1875,14 @@ func TestHandleLogout(t *testing.T) {
|
||||
},
|
||||
endSessionURL: "",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
expectedURL: "/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Logout with empty session",
|
||||
setupSession: func(session *SessionData) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
expectedURL: "/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
@@ -2349,19 +2349,22 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
t.Skip("Skipping test in short mode")
|
||||
}
|
||||
|
||||
// Create mock provider metadata server
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create mock provider metadata server. Issuer + endpoints must share the
|
||||
// host with ProviderURL (the httptest server), otherwise the discovery doc
|
||||
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
|
||||
var mockServer *httptest.Server
|
||||
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
RevokeURL: "https://test-issuer.com/revoke",
|
||||
EndSessionURL: "https://test-issuer.com/end-session",
|
||||
Issuer: mockServer.URL,
|
||||
AuthURL: mockServer.URL + "/auth",
|
||||
TokenURL: mockServer.URL + "/token",
|
||||
JWKSURL: mockServer.URL + "/jwks",
|
||||
RevokeURL: mockServer.URL + "/revoke",
|
||||
EndSessionURL: mockServer.URL + "/end-session",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
@@ -2374,6 +2377,7 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create multiple middleware instances
|
||||
@@ -2414,18 +2418,20 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
t.Fatalf("Middleware instance %d failed to initialize", i)
|
||||
}
|
||||
|
||||
// Verify each instance has its own unique configuration
|
||||
if m.issuerURL != "https://test-issuer.com" {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
|
||||
// Verify each instance has its own unique configuration. Issuer is now
|
||||
// pinned to the provider host (audit ranks 21/22), so it equals the
|
||||
// mock server URL rather than a fixed literal.
|
||||
if m.issuerURL != mockServer.URL {
|
||||
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
|
||||
}
|
||||
if m.authURL != "https://test-issuer.com/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
|
||||
if m.authURL != mockServer.URL+"/auth" {
|
||||
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
|
||||
}
|
||||
if m.tokenURL != "https://test-issuer.com/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
|
||||
if m.tokenURL != mockServer.URL+"/token" {
|
||||
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
|
||||
}
|
||||
if m.jwksURL != "https://test-issuer.com/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
|
||||
if m.jwksURL != mockServer.URL+"/jwks" {
|
||||
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
|
||||
}
|
||||
if m.redirURLPath != routes[i]+"/callback" {
|
||||
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
|
||||
@@ -2439,15 +2445,16 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
|
||||
m.ServeHTTP(rr, req)
|
||||
|
||||
// Should redirect to auth URL since not authenticated
|
||||
// Should redirect (302) to the auth flow since not authenticated. The
|
||||
// absolute auth URL is not asserted here: with issuer pinning (audit
|
||||
// ranks 21/22) the discovery host equals the httptest server host,
|
||||
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
|
||||
// to emit a loopback authorization URL in this test environment. The
|
||||
// per-instance auth/token/jwks/issuer URLs were already verified above;
|
||||
// here we only confirm each instance independently triggers a redirect.
|
||||
if rr.Code != http.StatusFound {
|
||||
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
|
||||
}
|
||||
|
||||
location := rr.Header().Get("Location")
|
||||
if !strings.Contains(location, "https://test-issuer.com/auth") {
|
||||
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2460,33 +2467,43 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create two mock provider metadata servers simulating different Keycloak realms
|
||||
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Issuer + endpoints must share the host with each realm's ProviderURL
|
||||
// (the httptest server), otherwise the discovery doc is rejected as
|
||||
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
|
||||
// path so the per-realm isolation assertions below still hold, but base
|
||||
// the host on the server URL — which is exactly what a same-host Keycloak
|
||||
// deployment looks like.
|
||||
var realm1Server *httptest.Server
|
||||
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
base := realm1Server.URL + "/realms/realm1"
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm1",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
|
||||
Issuer: base,
|
||||
AuthURL: base + "/protocol/openid-connect/auth",
|
||||
TokenURL: base + "/protocol/openid-connect/token",
|
||||
JWKSURL: base + "/protocol/openid-connect/certs",
|
||||
EndSessionURL: base + "/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer realm1Server.Close()
|
||||
|
||||
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var realm2Server *httptest.Server
|
||||
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
base := realm2Server.URL + "/realms/realm2"
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm2",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
|
||||
Issuer: base,
|
||||
AuthURL: base + "/protocol/openid-connect/auth",
|
||||
TokenURL: base + "/protocol/openid-connect/token",
|
||||
JWKSURL: base + "/protocol/openid-connect/certs",
|
||||
EndSessionURL: base + "/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
@@ -2500,6 +2517,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
CallbackURL: "/realm1/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm1_",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Config for realm2
|
||||
@@ -2510,6 +2528,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
CallbackURL: "/realm2/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm2_",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware instances for both realms
|
||||
@@ -2608,8 +2627,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable := false
|
||||
var mu sync.Mutex
|
||||
|
||||
// Create mock provider that initially fails, then becomes available
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create mock provider that initially fails, then becomes available.
|
||||
// Issuer + endpoints must share the host with ProviderURL (audit ranks
|
||||
// 21/22), so derive them from the server URL.
|
||||
var mockServer *httptest.Server
|
||||
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
available := providerAvailable
|
||||
mu.Unlock()
|
||||
@@ -2621,11 +2643,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
EndSessionURL: "https://test-issuer.com/logout",
|
||||
Issuer: mockServer.URL,
|
||||
AuthURL: mockServer.URL + "/auth",
|
||||
TokenURL: mockServer.URL + "/token",
|
||||
JWKSURL: mockServer.URL + "/jwks",
|
||||
EndSessionURL: mockServer.URL + "/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
return
|
||||
@@ -2640,6 +2662,7 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware while provider is unavailable
|
||||
@@ -4552,6 +4575,7 @@ func TestNewWithScopeAppending(t *testing.T) {
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
Scopes: tc.configScopes,
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
// Create middleware instance
|
||||
|
||||
@@ -1652,6 +1652,7 @@ func TestGoroutineLeaks(t *testing.T) {
|
||||
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
|
||||
config.ClientID = "test-client"
|
||||
config.ClientSecret = "test-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
|
||||
handler, err := New(context.Background(), nil, config, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
+11
-1
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -141,10 +142,19 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
|
||||
}
|
||||
|
||||
var metadata ProviderMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode metadata: %w", err)
|
||||
}
|
||||
|
||||
// Pin the advertised issuer to the configured provider host. The issuer is
|
||||
// the trust anchor for JWT issuer validation; rejecting a mismatch here
|
||||
// ensures a poisoned discovery document advertising an attacker-chosen
|
||||
// issuer is never cached or returned. Real providers (Google, Azure,
|
||||
// Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL.
|
||||
if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) {
|
||||
return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL)
|
||||
}
|
||||
|
||||
// Cache for 1 hour by default
|
||||
if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil {
|
||||
mc.logger.Errorf("Failed to cache metadata: %v", err)
|
||||
|
||||
+80
-7
@@ -472,6 +472,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - req: The HTTP request to process.
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
//
|
||||
// processAuthorizedRequestRS is the requestState-aware variant of
|
||||
// processAuthorizedRequest. It reads SessionData fields from the captured
|
||||
// snapshot in rs instead of calling session.GetX() (each of which acquires
|
||||
@@ -675,6 +676,44 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
//
|
||||
// Session persistence is the CALLER's responsibility — it must happen before
|
||||
// this function so Set-Cookie reaches the response.
|
||||
// headerTemplateMaxLen bounds the length of a rendered operator-defined header
|
||||
// template before it is forwarded downstream. Generous enough for an
|
||||
// "Authorization: Bearer <jwt>" value but small enough to reject obviously
|
||||
// abusive output. Matches the input-validation default header cap (8KB).
|
||||
const headerTemplateMaxLen = 8192
|
||||
|
||||
// headerClaimMaxLen returns the maximum accepted length for a claim-derived
|
||||
// header value (principal identifier, group, role). Reuses the operator-
|
||||
// configured identifier cap (default 256) so a single setting governs both
|
||||
// auth paths; falls back to 256 when unset.
|
||||
func (t *TraefikOidc) headerClaimMaxLen() int {
|
||||
if t.maxIdentifierLength > 0 {
|
||||
return t.maxIdentifierLength
|
||||
}
|
||||
return 256
|
||||
}
|
||||
|
||||
// sanitizeHeaderClaimList drops any group/role value that fails claim
|
||||
// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an
|
||||
// over-long value) and returns the surviving values. Failing closed on a bad
|
||||
// entry prevents header injection and stops an embedded comma from injecting
|
||||
// extra entries into the comma-joined header. headerName is used only for
|
||||
// debug logging — the value is never logged.
|
||||
func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
safe := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok {
|
||||
safe = append(safe, clean)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName)
|
||||
}
|
||||
}
|
||||
return safe
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) {
|
||||
var (
|
||||
groups, roles []string
|
||||
@@ -692,11 +731,18 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
|
||||
return
|
||||
}
|
||||
if extractErr == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
// Sanitize each group/role before it is joined into a comma-
|
||||
// delimited header. The cookie/session path does not otherwise
|
||||
// sanitize claim-derived values (the bearer path sanitizes its
|
||||
// identifier at construction), so a control char would enable
|
||||
// header injection and an embedded comma would inject extra
|
||||
// entries into the comma-joined header. Fail closed: drop any
|
||||
// value that does not pass.
|
||||
if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(safeGroups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(safeRoles, ","))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -717,12 +763,26 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", p.Identifier)
|
||||
// Sanitize the principal identifier before injecting it into headers. The
|
||||
// bearer path already sanitizes its identifier at construction; the
|
||||
// cookie/session path does not, so a claim carrying control chars, bidi-
|
||||
// override runes, or , ; = could inject or spoof header content. Fail
|
||||
// closed: drop the identifier header(s) rather than forward a tainted value.
|
||||
safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen())
|
||||
if identifierOK {
|
||||
req.Header.Set("X-Forwarded-User", safeIdentifier)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization")
|
||||
}
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", p.Identifier)
|
||||
if identifierOK {
|
||||
req.Header.Set("X-Auth-Request-User", safeIdentifier)
|
||||
} else {
|
||||
t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization")
|
||||
}
|
||||
if p.IDToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", p.IDToken)
|
||||
}
|
||||
@@ -747,8 +807,21 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
// Sanitize the rendered output: template inputs are claim-derived
|
||||
// and attacker-influenceable, so reject control chars (header
|
||||
// injection), bidi-override runes, the , ; = delimiters, and an
|
||||
// over-long value. Fail closed by dropping the header rather than
|
||||
// forwarding a tainted value. Do not log the value (it commonly
|
||||
// carries the access token); log only name + reason.
|
||||
if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" {
|
||||
t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason)
|
||||
continue
|
||||
}
|
||||
req.Header.Set(headerName, headerValue)
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
// Do not log the value: templated headers commonly carry the access
|
||||
// token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging
|
||||
// it — even at debug — leaks credentials into logs.
|
||||
t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -498,7 +498,7 @@ func (t *refreshAttemptTracker) mutateState(mutate func(cur *attemptState) *atte
|
||||
if next == nil {
|
||||
return cur
|
||||
}
|
||||
if t.state.CompareAndSwap(t.state.Load(), next) {
|
||||
if t.state.CompareAndSwap(cur, next) {
|
||||
return next
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// TestRank1_SessionCookieIsEncrypted verifies that the session cookie payload is
|
||||
// AES-encrypted, not merely HMAC-signed. Regression test for the audit finding
|
||||
// "session cookies signed but NOT encrypted": a single key left the stored OIDC
|
||||
// tokens recoverable in plaintext from the raw cookie bytes.
|
||||
func TestRank1_SessionCookieIsEncrypted(t *testing.T) {
|
||||
const secret = "a-sufficiently-long-session-encryption-key"
|
||||
authKey, encKey := deriveCookieKeys(secret)
|
||||
if len(authKey) != 64 || len(encKey) != 32 {
|
||||
t.Fatalf("expected 64-byte auth key and 32-byte enc key, got %d/%d", len(authKey), len(encKey))
|
||||
}
|
||||
if string(authKey) == string(encKey) {
|
||||
t.Fatal("authentication and encryption keys must be independent")
|
||||
}
|
||||
|
||||
const marker = "SUPER-SECRET-ACCESS-TOKEN-marker-value"
|
||||
|
||||
// Encode a session through the same two-key store the production code now
|
||||
// builds (see NewSessionManager).
|
||||
store := sessions.NewCookieStore(authKey, encKey)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
sess, err := store.New(req, "session")
|
||||
if err != nil {
|
||||
t.Fatalf("store.New failed: %v", err)
|
||||
}
|
||||
sess.Values["tok"] = marker
|
||||
if err := sess.Save(req, rec); err != nil {
|
||||
t.Fatalf("session save failed: %v", err)
|
||||
}
|
||||
|
||||
var cookie *http.Cookie
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == "session" {
|
||||
cookie = c
|
||||
}
|
||||
}
|
||||
if cookie == nil {
|
||||
t.Fatal("no session cookie was set")
|
||||
}
|
||||
|
||||
// The secret token must never appear in plaintext in the cookie value.
|
||||
if strings.Contains(cookie.Value, marker) {
|
||||
t.Error("marker token found in plaintext inside the session cookie value")
|
||||
}
|
||||
|
||||
// A store holding only the authentication key (the previous behavior)
|
||||
// must NOT be able to read the encrypted cookie — proving the payload is
|
||||
// genuinely encrypted, not just signed.
|
||||
signedOnly := sessions.NewCookieStore(authKey)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req2.AddCookie(cookie)
|
||||
if _, derr := signedOnly.Get(req2, "session"); derr == nil {
|
||||
t.Error("encrypted cookie should not be decodable without the encryption key")
|
||||
}
|
||||
|
||||
// The full two-key store round-trips correctly.
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req3.AddCookie(cookie)
|
||||
rt, derr := store.Get(req3, "session")
|
||||
if derr != nil {
|
||||
t.Fatalf("round-trip decode failed: %v", derr)
|
||||
}
|
||||
if got, _ := rt.Values["tok"].(string); got != marker {
|
||||
t.Errorf("round-trip mismatch: got %q want %q", got, marker)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank2And6_InvalidConfigFailsClosed verifies that NewWithContext now calls
|
||||
// Config.Validate() and fails closed on an empty or too-short session
|
||||
// encryption key instead of silently substituting a public hardcoded key, and
|
||||
// rejects other missing required fields. Regression test for "hardcoded default
|
||||
// encryption key" + "Config.Validate() never called in production path".
|
||||
func TestRank2And6_InvalidConfigFailsClosed(t *testing.T) {
|
||||
base := func() *Config {
|
||||
return &Config{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "this-is-a-valid-session-key-32b!",
|
||||
RateLimit: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity: a fully valid config still constructs.
|
||||
p, err := NewWithContext(context.Background(), base(), nil, "valid")
|
||||
if err != nil {
|
||||
t.Fatalf("valid config should construct, got: %v", err)
|
||||
}
|
||||
if p != nil {
|
||||
p.Close()
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mutate func(*Config)
|
||||
}{
|
||||
{"empty key", func(c *Config) { c.SessionEncryptionKey = "" }},
|
||||
{"short key", func(c *Config) { c.SessionEncryptionKey = "tooshort" }},
|
||||
{"missing providerURL", func(c *Config) { c.ProviderURL = "" }},
|
||||
{"missing callbackURL", func(c *Config) { c.CallbackURL = "" }},
|
||||
{"plaintext remote providerURL", func(c *Config) { c.ProviderURL = "http://accounts.google.com" }},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := base()
|
||||
tc.mutate(c)
|
||||
plugin, err := NewWithContext(context.Background(), c, nil, tc.name)
|
||||
if err == nil {
|
||||
if plugin != nil {
|
||||
plugin.Close()
|
||||
}
|
||||
t.Errorf("expected NewWithContext to reject config (%s), but it succeeded", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank3_DiscoveredEndpointSSRFGuard verifies that endpoints from the
|
||||
// provider discovery document are screened against SSRF targets before use.
|
||||
func TestRank3_DiscoveredEndpointSSRFGuard(t *testing.T) {
|
||||
tr := &TraefikOidc{}
|
||||
|
||||
blocked := []string{
|
||||
"http://169.254.169.254/latest/meta-data/", // cloud metadata (link-local)
|
||||
"http://[fe80::1]/jwks", // IPv6 link-local
|
||||
"http://10.0.0.5/jwks", // private
|
||||
"http://192.168.1.10/jwks", // private
|
||||
"http://127.0.0.1/jwks", // loopback (allowLoopback=false)
|
||||
"ftp://example.com/jwks", // disallowed scheme
|
||||
}
|
||||
for _, u := range blocked {
|
||||
if err := tr.validateDiscoveredEndpoint(u, false); err == nil {
|
||||
t.Errorf("expected discovered endpoint %q to be rejected", u)
|
||||
}
|
||||
}
|
||||
|
||||
allowed := []string{
|
||||
"https://accounts.google.com/o/oauth2/v3/certs",
|
||||
"https://www.googleapis.com/oauth2/v3/certs", // cross-domain JWKS must stay allowed
|
||||
"", // empty optional endpoint
|
||||
}
|
||||
for _, u := range allowed {
|
||||
if err := tr.validateDiscoveredEndpoint(u, false); err != nil {
|
||||
t.Errorf("expected discovered endpoint %q to be allowed, got %v", u, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Loopback is allowed only when the provider itself is loopback (dev/test).
|
||||
if err := tr.validateDiscoveredEndpoint("http://127.0.0.1:8080/jwks", true); err != nil {
|
||||
t.Errorf("loopback endpoint should be allowed when allowLoopback=true: %v", err)
|
||||
}
|
||||
// Private addresses are allowed when explicitly opted in.
|
||||
trPriv := &TraefikOidc{allowPrivateIPAddresses: true}
|
||||
if err := trPriv.validateDiscoveredEndpoint("http://10.0.0.5/jwks", false); err != nil {
|
||||
t.Errorf("private endpoint should be allowed when allowPrivateIPAddresses=true: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank4_IntrospectionHostPin verifies the host-equality check used to pin
|
||||
// the credential-bearing introspection endpoint to the configured provider.
|
||||
func TestRank4_IntrospectionHostPin(t *testing.T) {
|
||||
if !sameHost("https://kc.example.com/realms/x", "https://kc.example.com/realms/x/protocol/openid-connect/token/introspect") {
|
||||
t.Error("introspection on the same host as the provider should be accepted")
|
||||
}
|
||||
if sameHost("https://kc.example.com", "https://evil.example.net/introspect") {
|
||||
t.Error("introspection on a different host must be rejected")
|
||||
}
|
||||
if sameHost("", "https://kc.example.com") || sameHost("https://kc.example.com", "") {
|
||||
t.Error("empty URL must not be treated as a host match")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank5_OpenRedirectNeutralized verifies the helper the callback now applies
|
||||
// to the stored incoming path forces a host-relative redirect target.
|
||||
func TestRank5_OpenRedirectNeutralized(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"//evil.com/x": "/evil.com/x",
|
||||
`/\evil.com`: "/evil.com",
|
||||
"/legit/path": "/legit/path",
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := normalizeLogoutPath(in)
|
||||
if got != want {
|
||||
t.Errorf("normalizeLogoutPath(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
if strings.HasPrefix(got, "//") || strings.HasPrefix(got, `/\`) {
|
||||
t.Errorf("normalizeLogoutPath(%q) = %q is still protocol-relative", in, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank14_ExcludedURLSegmentBoundary verifies excluded-URL matching is
|
||||
// anchored at path-segment boundaries and cannot be widened into a bypass.
|
||||
func TestRank14_ExcludedURLSegmentBoundary(t *testing.T) {
|
||||
if !pathExcluded("/public", "/public") {
|
||||
t.Error("exact match should be excluded")
|
||||
}
|
||||
if !pathExcluded("/public/page", "/public") {
|
||||
t.Error("sub-path should be excluded")
|
||||
}
|
||||
if pathExcluded("/publicsecret", "/public") {
|
||||
t.Error("/publicsecret must NOT be excluded by /public")
|
||||
}
|
||||
if pathExcluded("/public-admin", "/public") {
|
||||
t.Error("/public-admin must NOT be excluded by /public")
|
||||
}
|
||||
if !pathExcluded("/health", "/health/") {
|
||||
t.Error("trailing-slash config should still match the exact path")
|
||||
}
|
||||
if pathExcluded("/anything", "/") {
|
||||
t.Error("root exclusion must not match arbitrary paths")
|
||||
}
|
||||
if !pathExcluded("/", "/") {
|
||||
t.Error("root exclusion should match the root path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank15_ForwardedHostSanitized verifies a crafted X-Forwarded-Host cannot
|
||||
// inject CRLF, smuggle a second host, or otherwise poison the derived host.
|
||||
func TestRank15_ForwardedHostSanitized(t *testing.T) {
|
||||
mk := func(xfh string) *http.Request {
|
||||
r := httptest.NewRequest(http.MethodGet, "http://real.example.com/x", nil)
|
||||
r.Host = "real.example.com"
|
||||
if xfh != "" {
|
||||
r.Header.Set("X-Forwarded-Host", xfh)
|
||||
}
|
||||
return r
|
||||
}
|
||||
if got := utils.DetermineHost(mk("ext.example.com")); got != "ext.example.com" {
|
||||
t.Errorf("clean X-Forwarded-Host should be honored, got %q", got)
|
||||
}
|
||||
if got := utils.DetermineHost(mk("a.example.com, evil.com")); got != "a.example.com" {
|
||||
t.Errorf("multi-value X-Forwarded-Host should use first host only, got %q", got)
|
||||
}
|
||||
for _, bad := range []string{"evil.com\r\nSet-Cookie: x=1", "evil.com /x", " "} {
|
||||
if got := utils.DetermineHost(mk(bad)); got != "real.example.com" {
|
||||
t.Errorf("malformed X-Forwarded-Host %q should fall back to req.Host, got %q", bad, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank11_TransportPoolTLSIsolationAtLimit verifies that, once the client
|
||||
// limit is reached, the transport pool reuses an existing transport only when
|
||||
// its TLS settings match the caller's, and never hands back a transport built
|
||||
// with different TLS trust settings.
|
||||
func TestRank11_TransportPoolTLSIsolationAtLimit(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
strict := DefaultHTTPClientConfig() // InsecureSkipVerify = false
|
||||
t1 := pool.GetOrCreateTransport(strict)
|
||||
if t1 == nil {
|
||||
t.Fatal("expected a transport for the strict config")
|
||||
}
|
||||
|
||||
// Saturate the client limit so subsequent calls hit the fallback path.
|
||||
atomic.StoreInt32(&pool.clientCount, pool.maxClients)
|
||||
|
||||
// Same TLS settings, different (non-TLS) connection limit: safe to reuse.
|
||||
sameTLS := DefaultHTTPClientConfig()
|
||||
sameTLS.MaxConnsPerHost = 99
|
||||
if got := pool.GetOrCreateTransport(sameTLS); got != t1 {
|
||||
t.Error("at the limit a TLS-compatible config should reuse the existing transport")
|
||||
}
|
||||
|
||||
// Different TLS settings (InsecureSkipVerify): must NOT reuse the strict
|
||||
// transport — returning nil lets the caller fall back to a verifying default.
|
||||
insecure := DefaultHTTPClientConfig()
|
||||
insecure.InsecureSkipVerify = true
|
||||
if got := pool.GetOrCreateTransport(insecure); got == t1 {
|
||||
t.Error("at the limit a config with different TLS settings must not reuse the strict transport")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank9_RedisFingerprint verifies divergent explicit Redis backends produce
|
||||
// distinct fingerprints (used to warn about ignored cache config), while an
|
||||
// absent or disabled Redis yields the empty (no-warning) fingerprint.
|
||||
func TestRank9_RedisFingerprint(t *testing.T) {
|
||||
if redisFingerprint(nil) != "" {
|
||||
t.Error("nil config should yield an empty fingerprint")
|
||||
}
|
||||
if redisFingerprint(&Config{}) != "" {
|
||||
t.Error("config without Redis should yield an empty fingerprint")
|
||||
}
|
||||
if redisFingerprint(&Config{Redis: &RedisConfig{Enabled: false, Address: "a:6379"}}) != "" {
|
||||
t.Error("disabled Redis should yield an empty fingerprint")
|
||||
}
|
||||
a := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "a:6379", KeyPrefix: "p"}})
|
||||
b := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "b:6379", KeyPrefix: "p"}})
|
||||
if a == "" || a == b {
|
||||
t.Errorf("distinct enabled backends must produce distinct non-empty fingerprints (%q vs %q)", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank10_TokenTypeCacheKeyNoCollision verifies that two different tokens
|
||||
// sharing the same 32-character JWT header prefix are classified independently.
|
||||
// The previous 32-char cache key would have collided and mis-classified them.
|
||||
func TestRank10_TokenTypeCacheKeyNoCollision(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
tokenTypeCache: NewCache(),
|
||||
suppressDiagnosticLogs: true,
|
||||
clientID: "client",
|
||||
}
|
||||
// A header prefix longer than 32 chars, shared by both tokens.
|
||||
prefix := "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ"
|
||||
idJWT := &JWT{Header: map[string]interface{}{}, Claims: map[string]interface{}{"nonce": "n"}}
|
||||
accessJWT := &JWT{Header: map[string]interface{}{"typ": "at+jwt"}, Claims: map[string]interface{}{}}
|
||||
|
||||
if !tr.detectTokenType(idJWT, prefix+".id.sig") {
|
||||
t.Error("token with a nonce claim should be detected as an ID token")
|
||||
}
|
||||
if tr.detectTokenType(accessJWT, prefix+".access.sig") {
|
||||
t.Error("access token (typ=at+jwt) must not be mis-classified as ID despite the shared 32-char prefix")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank12_LiveInstanceCounter verifies the process-global instance counter
|
||||
// that gates teardown of shared singleton tasks.
|
||||
func TestRank12_LiveInstanceCounter(t *testing.T) {
|
||||
start := atomic.LoadInt32(&liveInstanceCount)
|
||||
registerLiveInstance()
|
||||
registerLiveInstance()
|
||||
if got := atomic.LoadInt32(&liveInstanceCount); got != start+2 {
|
||||
t.Fatalf("expected %d live instances, got %d", start+2, got)
|
||||
}
|
||||
if rem := unregisterLiveInstance(); rem != start+1 {
|
||||
t.Errorf("expected %d remaining, got %d", start+1, rem)
|
||||
}
|
||||
if rem := unregisterLiveInstance(); rem != start {
|
||||
t.Errorf("expected %d remaining, got %d", start, rem)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank13_CookieMaxAgeMatchesSessionLifetime verifies the cookie store's
|
||||
// MaxAge (which bounds both the cookie Max-Age and the codec's cryptographic
|
||||
// timestamp validity) is bound to the configured session lifetime rather than
|
||||
// gorilla's 30-day default.
|
||||
func TestRank13_CookieMaxAgeMatchesSessionLifetime(t *testing.T) {
|
||||
maxAge := 2 * time.Hour
|
||||
sm, err := NewSessionManager(strings.Repeat("k", 40), false, "", "", maxAge, NewLogger("error"))
|
||||
if err != nil {
|
||||
t.Fatalf("NewSessionManager failed: %v", err)
|
||||
}
|
||||
defer sm.cancel()
|
||||
|
||||
cs, ok := sm.store.(*sessions.CookieStore)
|
||||
if !ok {
|
||||
t.Fatal("session store is not a *sessions.CookieStore")
|
||||
}
|
||||
if got := cs.Options.MaxAge; got != int(maxAge.Seconds()) {
|
||||
t.Errorf("cookie store MaxAge = %d, want %d (bound to sessionMaxAge)", got, int(maxAge.Seconds()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestRank33And34_HeaderSanitizationDistinction verifies the two header sinks
|
||||
// use the right strictness: free-form templated header VALUES (rank 34) permit
|
||||
// , ; = (e.g. an opaque "Bearer <token>" or an LDAP-DN claim) but reject CR/LF,
|
||||
// bidi, and over-length; claim values joined into delimited/identifier headers
|
||||
// (rank 33) additionally reject , ; =.
|
||||
func TestRank33And34_HeaderSanitizationDistinction(t *testing.T) {
|
||||
// Rank 34 — free-form header value.
|
||||
if headerValueReason("Bearer abc=def==", 8192) != "" {
|
||||
t.Error("'=' must be allowed in a free-form header value (opaque bearer token)")
|
||||
}
|
||||
if headerValueReason("cn=user,ou=eng;dc=x", 8192) != "" {
|
||||
t.Error("',;=' must be allowed in a free-form header value (e.g. an LDAP DN claim)")
|
||||
}
|
||||
if headerValueReason("evil"+string(rune(13))+string(rune(10))+"Injected: 1", 8192) == "" {
|
||||
t.Error("CR/LF must be rejected in a header value (injection)")
|
||||
}
|
||||
if headerValueReason("toolong", 3) == "" {
|
||||
t.Error("over-length value must be rejected")
|
||||
}
|
||||
|
||||
// Rank 33 — claim value bound for a delimited/identifier header.
|
||||
if _, ok := sanitizeHeaderClaimValue("admins,superadmins", 256); ok {
|
||||
t.Error("a comma must be rejected in a value joined into a comma-delimited header")
|
||||
}
|
||||
if _, ok := sanitizeHeaderClaimValue("normal-user@example.com", 256); !ok {
|
||||
t.Error("a clean identifier must pass claim sanitization")
|
||||
}
|
||||
if _, ok := sanitizeHeaderClaimValue("evil"+string(rune(13))+string(rune(10))+"X: 1", 256); ok {
|
||||
t.Error("CR/LF must be rejected in a claim value")
|
||||
}
|
||||
}
|
||||
@@ -1,590 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecurityEventType categorizes different types of security events
|
||||
// that can occur during OIDC authentication and authorization flows.
|
||||
type SecurityEventType string
|
||||
|
||||
// Security event types for monitoring and alerting
|
||||
const (
|
||||
// AuthFailure indicates a failed authentication attempt
|
||||
AuthFailure SecurityEventType = "authentication_failure"
|
||||
// TokenValidFailure indicates JWT token validation failed
|
||||
TokenValidFailure SecurityEventType = "token_validation_failure"
|
||||
// RateLimitHit indicates rate limiting was triggered
|
||||
RateLimitHit SecurityEventType = "rate_limit_hit"
|
||||
// SuspiciousActivity indicates potentially malicious behavior
|
||||
SuspiciousActivity SecurityEventType = "suspicious_activity"
|
||||
)
|
||||
|
||||
// DefaultSeverity returns the default severity level for each security event type.
|
||||
// Severity levels are: low, medium, high.
|
||||
func (t SecurityEventType) DefaultSeverity() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "medium"
|
||||
case TokenValidFailure:
|
||||
return "medium"
|
||||
case RateLimitHit:
|
||||
return "low"
|
||||
case SuspiciousActivity:
|
||||
return "high"
|
||||
default:
|
||||
return "medium"
|
||||
}
|
||||
}
|
||||
|
||||
// IPFailureType returns a string identifier for categorizing failures
|
||||
// by IP address for rate limiting and blocking decisions.
|
||||
func (t SecurityEventType) IPFailureType() string {
|
||||
switch t {
|
||||
case AuthFailure:
|
||||
return "auth_failure"
|
||||
case TokenValidFailure:
|
||||
return "token_failure"
|
||||
case SuspiciousActivity:
|
||||
return "suspicious"
|
||||
default:
|
||||
return "general"
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityEvent represents a security-related event with comprehensive context.
|
||||
// Contains timing information, IP address, user agent, request details,
|
||||
// and custom event-specific data for security analysis and alerting.
|
||||
type SecurityEvent struct {
|
||||
// Timestamp when the event occurred
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
// Details contains event-specific additional information
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
// Type categorizes the event (auth_failure, token_failure, etc.)
|
||||
Type string `json:"type"`
|
||||
// Severity indicates event importance (low, medium, high)
|
||||
Severity string `json:"severity"`
|
||||
// ClientIP is the source IP address of the request
|
||||
ClientIP string `json:"client_ip"`
|
||||
// UserAgent is the User-Agent header from the request
|
||||
UserAgent string `json:"user_agent"`
|
||||
// RequestPath is the requested URL path
|
||||
RequestPath string `json:"request_path"`
|
||||
// Message provides human-readable description of the event
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SecurityMonitor provides comprehensive security monitoring for the OIDC middleware.
|
||||
// It tracks failures by IP address, detects suspicious patterns, enforces
|
||||
// rate limits, and can trigger custom security event handlers.
|
||||
type SecurityMonitor struct {
|
||||
ipFailures map[string]*IPFailureTracker
|
||||
patternDetector *SuspiciousPatternDetector
|
||||
logger *Logger
|
||||
cleanupTask *BackgroundTask
|
||||
eventHandlers []SecurityEventHandler
|
||||
config SecurityMonitorConfig
|
||||
ipMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// IPFailureTracker maintains failure statistics and blocking state for an IP address.
|
||||
// Used for implementing progressive penalties and automatic IP blocking based on
|
||||
// failure patterns, with support for different failure types for
|
||||
// rate limiting and IP blocking decisions.
|
||||
type IPFailureTracker struct {
|
||||
// LastFailure timestamp of the most recent failure
|
||||
LastFailure time.Time
|
||||
// FirstFailure timestamp of the first failure in current window
|
||||
FirstFailure time.Time
|
||||
// BlockedUntil indicates when the IP block expires
|
||||
BlockedUntil time.Time
|
||||
// FailureTypes tracks counts by failure type
|
||||
FailureTypes map[string]int64
|
||||
// FailureCount total number of failures
|
||||
FailureCount int64
|
||||
// mutex protects concurrent access to tracker data
|
||||
mutex sync.RWMutex
|
||||
// IsBlocked indicates if this IP is currently blocked
|
||||
IsBlocked bool
|
||||
}
|
||||
|
||||
// SuspiciousPatternDetector identifies attack patterns that may indicate coordinated threats.
|
||||
// Analyzes events across multiple time windows to detect rapid failures, distributed attacks,
|
||||
// and persistent attack patterns that individual IP monitoring might miss.
|
||||
type SuspiciousPatternDetector struct {
|
||||
// recentEvents stores recent security events for analysis
|
||||
recentEvents []SecurityEvent
|
||||
// shortWindow defines time frame for rapid failure detection
|
||||
shortWindow time.Duration
|
||||
// mediumWindow defines time frame for distributed attack detection
|
||||
mediumWindow time.Duration
|
||||
// longWindow defines time frame for persistent attack detection
|
||||
longWindow time.Duration
|
||||
// rapidFailureThreshold triggers rapid failure alerts
|
||||
rapidFailureThreshold int
|
||||
// distributedAttackThreshold triggers distributed attack alerts
|
||||
distributedAttackThreshold int
|
||||
// persistentAttackThreshold triggers persistent attack alerts
|
||||
persistentAttackThreshold int
|
||||
// eventsMutex protects concurrent access to events
|
||||
eventsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// SecurityEventHandler defines the interface for processing security events.
|
||||
// Implementations can log events, send alerts, update external systems,
|
||||
// or trigger automated response actions.
|
||||
type SecurityEventHandler interface {
|
||||
// HandleSecurityEvent processes a security event
|
||||
HandleSecurityEvent(event SecurityEvent)
|
||||
}
|
||||
|
||||
// SecurityMonitorConfig contains configuration parameters for the security monitor.
|
||||
// Controls thresholds, time windows, and behavior for security monitoring.
|
||||
type SecurityMonitorConfig struct {
|
||||
// MaxFailuresPerIP sets the failure threshold before blocking
|
||||
MaxFailuresPerIP int `json:"max_failures_per_ip"`
|
||||
// FailureWindowMinutes defines the time window for counting failures
|
||||
FailureWindowMinutes int `json:"failure_window_minutes"`
|
||||
// BlockDurationMinutes sets how long to block an IP
|
||||
BlockDurationMinutes int `json:"block_duration_minutes"`
|
||||
// RapidFailureThreshold triggers rapid failure detection
|
||||
RapidFailureThreshold int `json:"rapid_failure_threshold"`
|
||||
// CleanupIntervalMinutes sets cleanup frequency for old data
|
||||
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
|
||||
RetentionHours int `json:"retention_hours"`
|
||||
EnablePatternDetection bool `json:"enable_pattern_detection"`
|
||||
EnableDetailedLogging bool `json:"enable_detailed_logging"`
|
||||
LogSuspiciousOnly bool `json:"log_suspicious_only"`
|
||||
}
|
||||
|
||||
// DefaultSecurityMonitorConfig returns a default configuration
|
||||
func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
|
||||
return SecurityMonitorConfig{
|
||||
MaxFailuresPerIP: 10,
|
||||
FailureWindowMinutes: 15,
|
||||
BlockDurationMinutes: 60,
|
||||
EnablePatternDetection: true,
|
||||
RapidFailureThreshold: 5,
|
||||
EnableDetailedLogging: true,
|
||||
LogSuspiciousOnly: false,
|
||||
CleanupIntervalMinutes: 30,
|
||||
RetentionHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSecurityMonitor creates a new security monitor instance
|
||||
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
|
||||
sm := &SecurityMonitor{
|
||||
ipFailures: make(map[string]*IPFailureTracker),
|
||||
eventHandlers: make([]SecurityEventHandler, 0),
|
||||
config: config,
|
||||
logger: logger,
|
||||
patternDetector: NewSuspiciousPatternDetector(),
|
||||
}
|
||||
|
||||
sm.startCleanupRoutine()
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// NewSuspiciousPatternDetector creates a new pattern detector
|
||||
func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
|
||||
return &SuspiciousPatternDetector{
|
||||
shortWindow: 1 * time.Minute,
|
||||
mediumWindow: 5 * time.Minute,
|
||||
longWindow: 15 * time.Minute,
|
||||
rapidFailureThreshold: 5,
|
||||
distributedAttackThreshold: 20,
|
||||
persistentAttackThreshold: 50,
|
||||
recentEvents: make([]SecurityEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSecurityEvent is a generic method to record any type of security event
|
||||
func (sm *SecurityMonitor) RecordSecurityEvent(
|
||||
eventType SecurityEventType,
|
||||
clientIP, userAgent, requestPath string,
|
||||
message string,
|
||||
details map[string]interface{},
|
||||
trackIPFailure bool) {
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: string(eventType),
|
||||
Severity: eventType.DefaultSeverity(),
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
UserAgent: userAgent,
|
||||
RequestPath: requestPath,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
if trackIPFailure {
|
||||
sm.recordIPFailure(clientIP, eventType.IPFailureType())
|
||||
}
|
||||
|
||||
sm.processSecurityEvent(event)
|
||||
}
|
||||
|
||||
// RecordAuthenticationFailure records an authentication failure event
|
||||
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["reason"] = reason
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
AuthFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Authentication failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordTokenValidationFailure records a token validation failure
|
||||
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
|
||||
details := map[string]interface{}{
|
||||
"reason": reason,
|
||||
}
|
||||
if tokenPrefix != "" {
|
||||
details["token_prefix"] = tokenPrefix
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
TokenValidFailure,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Token validation failed: %s", reason),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordRateLimitHit records when rate limiting is triggered
|
||||
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
|
||||
details := map[string]interface{}{
|
||||
"limit_type": "token_verification",
|
||||
}
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
RateLimitHit,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
"Rate limit exceeded",
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
|
||||
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
|
||||
if details == nil {
|
||||
details = make(map[string]interface{})
|
||||
}
|
||||
details["activity_type"] = activityType
|
||||
|
||||
sm.RecordSecurityEvent(
|
||||
SuspiciousActivity,
|
||||
clientIP,
|
||||
userAgent,
|
||||
requestPath,
|
||||
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
|
||||
details,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
// recordIPFailure tracks failures for a specific IP address
|
||||
func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
tracker = &IPFailureTracker{
|
||||
FailureTypes: make(map[string]int64),
|
||||
FirstFailure: time.Now(),
|
||||
}
|
||||
sm.ipFailures[clientIP] = tracker
|
||||
}
|
||||
|
||||
tracker.mutex.Lock()
|
||||
defer tracker.mutex.Unlock()
|
||||
|
||||
tracker.FailureCount++
|
||||
tracker.LastFailure = time.Now()
|
||||
tracker.FailureTypes[failureType]++
|
||||
|
||||
windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute)
|
||||
if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) {
|
||||
if !tracker.IsBlocked {
|
||||
tracker.IsBlocked = true
|
||||
tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute)
|
||||
|
||||
sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes)
|
||||
|
||||
blockEvent := SecurityEvent{
|
||||
Type: "ip_blocked",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
ClientIP: clientIP,
|
||||
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
|
||||
Details: map[string]interface{}{
|
||||
"failure_count": tracker.FailureCount,
|
||||
"failure_types": tracker.FailureTypes,
|
||||
"blocked_until": tracker.BlockedUntil,
|
||||
},
|
||||
}
|
||||
sm.processSecurityEvent(blockEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsIPBlocked checks if an IP address is currently blocked
|
||||
func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool {
|
||||
sm.ipMutex.RLock()
|
||||
defer sm.ipMutex.RUnlock()
|
||||
|
||||
tracker, exists := sm.ipFailures[clientIP]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
defer tracker.mutex.RUnlock()
|
||||
|
||||
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
|
||||
return true
|
||||
}
|
||||
|
||||
if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) {
|
||||
tracker.IsBlocked = false
|
||||
sm.logger.Infof("IP %s automatically unblocked", clientIP)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// processSecurityEvent processes a security event through all handlers and pattern detection
|
||||
func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
|
||||
if sm.config.EnablePatternDetection {
|
||||
sm.patternDetector.AddEvent(event)
|
||||
|
||||
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
|
||||
if len(patterns) == 1 {
|
||||
sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0])
|
||||
} else {
|
||||
sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns)
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
patternEvent := SecurityEvent{
|
||||
Type: "suspicious_pattern",
|
||||
Severity: "high",
|
||||
Timestamp: time.Now(),
|
||||
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
|
||||
Details: map[string]interface{}{
|
||||
"pattern_type": pattern,
|
||||
"trigger_event": event,
|
||||
},
|
||||
}
|
||||
sm.handleSecurityEvent(patternEvent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sm.handleSecurityEvent(event)
|
||||
}
|
||||
|
||||
// handleSecurityEvent sends the event to all registered handlers
|
||||
func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) {
|
||||
if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") {
|
||||
sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)",
|
||||
event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath)
|
||||
}
|
||||
|
||||
for _, handler := range sm.eventHandlers {
|
||||
go handler.HandleSecurityEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// AddEventHandler adds a security event handler
|
||||
func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
|
||||
sm.eventHandlers = append(sm.eventHandlers, handler)
|
||||
}
|
||||
|
||||
// This is kept for API compatibility but doesn't collect actual metrics
|
||||
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"tracked_ips": 0,
|
||||
}
|
||||
}
|
||||
|
||||
// AddEvent adds an event to the pattern detector
|
||||
func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) {
|
||||
spd.eventsMutex.Lock()
|
||||
defer spd.eventsMutex.Unlock()
|
||||
|
||||
spd.recentEvents = append(spd.recentEvents, event)
|
||||
|
||||
cutoff := time.Now().Add(-spd.longWindow)
|
||||
var filteredEvents []SecurityEvent
|
||||
for _, e := range spd.recentEvents {
|
||||
if e.Timestamp.After(cutoff) {
|
||||
filteredEvents = append(filteredEvents, e)
|
||||
}
|
||||
}
|
||||
spd.recentEvents = filteredEvents
|
||||
}
|
||||
|
||||
// DetectSuspiciousPatterns analyzes recent events for suspicious patterns
|
||||
func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
|
||||
spd.eventsMutex.RLock()
|
||||
defer spd.eventsMutex.RUnlock()
|
||||
|
||||
var patterns []string
|
||||
now := time.Now()
|
||||
|
||||
ipCounts := make(map[string]int)
|
||||
shortWindowStart := now.Add(-spd.shortWindow)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(shortWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
ipCounts[event.ClientIP]++
|
||||
}
|
||||
}
|
||||
|
||||
for ip, count := range ipCounts {
|
||||
if count >= spd.rapidFailureThreshold {
|
||||
patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip))
|
||||
}
|
||||
}
|
||||
|
||||
mediumWindowStart := now.Add(-spd.mediumWindow)
|
||||
uniqueFailingIPs := make(map[string]bool)
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(mediumWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
uniqueFailingIPs[event.ClientIP] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(uniqueFailingIPs) >= spd.distributedAttackThreshold {
|
||||
patterns = append(patterns, "distributed_attack_pattern")
|
||||
}
|
||||
|
||||
longWindowStart := now.Add(-spd.longWindow)
|
||||
persistentFailures := 0
|
||||
|
||||
for _, event := range spd.recentEvents {
|
||||
if event.Timestamp.After(longWindowStart) &&
|
||||
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
|
||||
persistentFailures++
|
||||
}
|
||||
}
|
||||
|
||||
if persistentFailures >= spd.persistentAttackThreshold {
|
||||
patterns = append(patterns, "persistent_attack_pattern")
|
||||
}
|
||||
|
||||
return patterns
|
||||
}
|
||||
|
||||
// startCleanupRoutine starts the background cleanup routine
|
||||
func (sm *SecurityMonitor) startCleanupRoutine() {
|
||||
sm.cleanupTask = NewBackgroundTask(
|
||||
"security-monitor-cleanup",
|
||||
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
|
||||
sm.cleanup,
|
||||
sm.logger)
|
||||
sm.cleanupTask.Start()
|
||||
}
|
||||
|
||||
// StopCleanupRoutine stops the background cleanup routine
|
||||
func (sm *SecurityMonitor) StopCleanupRoutine() {
|
||||
if sm.cleanupTask != nil {
|
||||
sm.cleanupTask.Stop()
|
||||
sm.cleanupTask = nil
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes old tracking data
|
||||
func (sm *SecurityMonitor) cleanup() {
|
||||
sm.ipMutex.Lock()
|
||||
defer sm.ipMutex.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour)
|
||||
|
||||
for ip, tracker := range sm.ipFailures {
|
||||
tracker.mutex.RLock()
|
||||
shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
if shouldRemove {
|
||||
delete(sm.ipFailures, ip)
|
||||
}
|
||||
}
|
||||
|
||||
sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures))
|
||||
}
|
||||
|
||||
// ExtractClientIP extracts the client IP from the request, considering proxy headers
|
||||
func ExtractClientIP(r *http.Request) string {
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if net.ParseIP(xri) != nil {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
ip := strings.TrimSpace(ips[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// LoggingSecurityEventHandler logs security events to the standard logger
|
||||
type LoggingSecurityEventHandler struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewLoggingSecurityEventHandler creates a new logging event handler
|
||||
func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler {
|
||||
return &LoggingSecurityEventHandler{logger: logger}
|
||||
}
|
||||
|
||||
// HandleSecurityEvent implements SecurityEventHandler
|
||||
func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
switch event.Severity {
|
||||
case "high":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "medium":
|
||||
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
case "low":
|
||||
h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
default:
|
||||
h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
|
||||
}
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSecurityMonitor(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
config.MaxFailuresPerIP = 3
|
||||
config.BlockDurationMinutes = 1 // 1 minute for testing
|
||||
config.CleanupIntervalMinutes = 1
|
||||
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
defer func() {
|
||||
// Allow cleanup goroutine to finish
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}()
|
||||
|
||||
t.Run("Record authentication failure", func(t *testing.T) {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "invalid credentials", nil)
|
||||
|
||||
// Should not be blocked after first failure
|
||||
if monitor.IsIPBlocked("192.168.1.1") {
|
||||
t.Error("IP should not be blocked after first failure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IP blocked after max failures", func(t *testing.T) {
|
||||
// Record multiple failures
|
||||
for i := 0; i < config.MaxFailuresPerIP; i++ {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.2", "test-agent", "/login", "invalid credentials", nil)
|
||||
}
|
||||
|
||||
// Should be blocked now
|
||||
if !monitor.IsIPBlocked("192.168.1.2") {
|
||||
t.Error("IP should be blocked after max failures")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token validation failure", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
|
||||
})
|
||||
|
||||
t.Run("Rate limit hit", func(t *testing.T) {
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
|
||||
})
|
||||
|
||||
t.Run("Suspicious activity", func(t *testing.T) {
|
||||
details := map[string]interface{}{"pattern": "unusual"}
|
||||
// Just verify the method doesn't panic
|
||||
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSuspiciousPatternDetector(t *testing.T) {
|
||||
detector := NewSuspiciousPatternDetector()
|
||||
|
||||
t.Run("Add events and detect patterns", func(t *testing.T) {
|
||||
// Add multiple events from same IP
|
||||
for i := 0; i < 10; i++ {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.100",
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
detector.AddEvent(event)
|
||||
}
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := false
|
||||
for _, p := range patterns {
|
||||
if p == "rapid_failures_from_ip_192.168.1.100" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to detect rapid failure pattern")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Detect distributed attack pattern", func(t *testing.T) {
|
||||
// Add failures from many different IPs
|
||||
for i := 0; i < 25; i++ {
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1." + strconv.Itoa(100+i),
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
detector.AddEvent(event)
|
||||
}
|
||||
|
||||
patterns := detector.DetectSuspiciousPatterns()
|
||||
|
||||
found := false
|
||||
for _, p := range patterns {
|
||||
if p == "distributed_attack_pattern" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to detect distributed attack pattern")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
headers map[string]string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "Direct connection",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 10.0.0.1"},
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP header",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{"X-Real-IP": "203.0.113.2"},
|
||||
expectedIP: "203.0.113.2",
|
||||
},
|
||||
{
|
||||
name: "Multiple headers - X-Real-IP takes precedence",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "203.0.113.1",
|
||||
"X-Real-IP": "203.0.113.2",
|
||||
},
|
||||
expectedIP: "203.0.113.2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
ip := ExtractClientIP(req)
|
||||
if ip != tt.expectedIP {
|
||||
t.Errorf("Expected IP %s, got %s", tt.expectedIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityEventHandlers(t *testing.T) {
|
||||
t.Run("Logging security event handler", func(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
handler := NewLoggingSecurityEventHandler(logger)
|
||||
|
||||
event := SecurityEvent{
|
||||
Type: "authentication_failure",
|
||||
ClientIP: "192.168.1.1",
|
||||
Timestamp: time.Now(),
|
||||
Message: "Test failure",
|
||||
Severity: "medium",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
handler.HandleSecurityEvent(event)
|
||||
})
|
||||
|
||||
// Metrics security event handler test removed as part of metrics cleanup
|
||||
}
|
||||
|
||||
func TestSecurityMonitorEventHandlers(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Add event handler with proper synchronization
|
||||
handlerCalled := make(chan bool, 1)
|
||||
handler := &testSecurityEventHandler{
|
||||
callback: func(event SecurityEvent) {
|
||||
select {
|
||||
case handlerCalled <- true:
|
||||
default:
|
||||
// Channel already has a value, don't block
|
||||
}
|
||||
},
|
||||
}
|
||||
monitor.AddEventHandler(handler)
|
||||
|
||||
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "test failure", nil)
|
||||
|
||||
// Wait for event handler to be called with timeout
|
||||
select {
|
||||
case <-handlerCalled:
|
||||
// Success - handler was called
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Expected event handler to be called within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper for security event handler
|
||||
type testSecurityEventHandler struct {
|
||||
callback func(SecurityEvent)
|
||||
}
|
||||
|
||||
func (h *testSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
|
||||
h.callback(event)
|
||||
}
|
||||
|
||||
func TestDefaultSecurityMonitorConfig(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
|
||||
if config.MaxFailuresPerIP <= 0 {
|
||||
t.Error("Expected positive MaxFailuresPerIP")
|
||||
}
|
||||
if config.BlockDurationMinutes <= 0 {
|
||||
t.Error("Expected positive BlockDurationMinutes")
|
||||
}
|
||||
if config.CleanupIntervalMinutes <= 0 {
|
||||
t.Error("Expected positive CleanupIntervalMinutes")
|
||||
}
|
||||
if config.FailureWindowMinutes <= 0 {
|
||||
t.Error("Expected positive FailureWindowMinutes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityMonitorCleanup(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
config.CleanupIntervalMinutes = 1
|
||||
config.BlockDurationMinutes = 1
|
||||
config.RetentionHours = 1
|
||||
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Block an IP
|
||||
for i := 0; i < config.MaxFailuresPerIP; i++ {
|
||||
monitor.RecordAuthenticationFailure("192.168.1.99", "test-agent", "/login", "test", nil)
|
||||
}
|
||||
|
||||
// Verify it's blocked
|
||||
if !monitor.IsIPBlocked("192.168.1.99") {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
|
||||
// Wait a bit and check if it gets unblocked automatically
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// The IP should still be blocked since we haven't waited long enough
|
||||
if !monitor.IsIPBlocked("192.168.1.99") {
|
||||
t.Error("IP should still be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityEventTypes(t *testing.T) {
|
||||
config := DefaultSecurityMonitorConfig()
|
||||
logger := NewLogger("debug")
|
||||
monitor := NewSecurityMonitor(config, logger)
|
||||
|
||||
// Test different event types - just verify they don't panic
|
||||
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
|
||||
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
|
||||
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
|
||||
|
||||
details := map[string]interface{}{"pattern": "test"}
|
||||
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
|
||||
|
||||
// Just verify GetSecurityMetrics doesn't panic
|
||||
_ = monitor.GetSecurityMetrics()
|
||||
}
|
||||
+62
-8
@@ -4,7 +4,9 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
@@ -31,6 +33,45 @@ func constantTimeStringCompare(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
|
||||
// deriveCookieKeys derives an independent 64-byte HMAC authentication key and a
|
||||
// 32-byte AES-256 encryption key from the operator-provided session encryption
|
||||
// key using HKDF-SHA256 (RFC 5869).
|
||||
//
|
||||
// gorilla/securecookie only ENCRYPTS the cookie payload when a block
|
||||
// (encryption) key is supplied; constructing the store with a single key leaves
|
||||
// sessions signed-but-plaintext, so the OIDC access/refresh/ID tokens stored in
|
||||
// the cookie are recoverable by anyone who can read the raw cookie bytes. Two
|
||||
// independent keys are derived here so the cookie is both encrypted and
|
||||
// authenticated. HKDF is implemented with stdlib hmac+sha256 so it runs under
|
||||
// Traefik's yaegi interpreter, which may not export crypto/hkdf.
|
||||
func deriveCookieKeys(secret string) (authKey, encKey []byte) {
|
||||
okm := hkdfSHA256([]byte(secret), nil, []byte("traefikoidc session cookie keys v1"), 96)
|
||||
return okm[:64], okm[64:96]
|
||||
}
|
||||
|
||||
// hkdfSHA256 performs HKDF-Extract followed by HKDF-Expand (RFC 5869) using
|
||||
// HMAC-SHA256 and returns length bytes of output keying material.
|
||||
func hkdfSHA256(ikm, salt, info []byte, length int) []byte {
|
||||
if len(salt) == 0 {
|
||||
salt = make([]byte, sha256.Size)
|
||||
}
|
||||
// Extract: PRK = HMAC-SHA256(salt, IKM)
|
||||
ext := hmac.New(sha256.New, salt)
|
||||
ext.Write(ikm)
|
||||
prk := ext.Sum(nil)
|
||||
// Expand: T(i) = HMAC-SHA256(PRK, T(i-1) | info | i)
|
||||
var out, t []byte
|
||||
for i := byte(1); len(out) < length; i++ {
|
||||
exp := hmac.New(sha256.New, prk)
|
||||
exp.Write(t)
|
||||
exp.Write(info)
|
||||
exp.Write([]byte{i})
|
||||
t = exp.Sum(nil)
|
||||
out = append(out, t...)
|
||||
}
|
||||
return out[:length]
|
||||
}
|
||||
|
||||
// min returns the minimum of two integers.
|
||||
// This is a utility function used throughout the session management code.
|
||||
// Parameters:
|
||||
@@ -118,12 +159,12 @@ var knownSessionKeys = map[string]bool{
|
||||
"id_token": true,
|
||||
"user_identifier": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
"incoming_path": true,
|
||||
"created_at": true,
|
||||
"redirect_count": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
"incoming_path": true,
|
||||
"created_at": true,
|
||||
"redirect_count": true,
|
||||
}
|
||||
|
||||
// compressCombinedPayload compresses the combined session payload using gzip.
|
||||
@@ -423,8 +464,13 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Derive independent authentication + encryption keys so the session cookie
|
||||
// is AES-256 encrypted and HMAC authenticated, not merely signed. See
|
||||
// deriveCookieKeys: a single key would leave the stored tokens in plaintext.
|
||||
authKey, encKey := deriveCookieKeys(encryptionKey)
|
||||
|
||||
sm := &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
store: sessions.NewCookieStore(authKey, encKey),
|
||||
forceHTTPS: forceHTTPS,
|
||||
cookieDomain: cookieDomain,
|
||||
cookiePrefix: cookiePrefix,
|
||||
@@ -435,6 +481,14 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Bind the cookie codec's timestamp validity (and the cookie Max-Age) to the
|
||||
// configured session lifetime instead of gorilla's 30-day default, so a
|
||||
// stolen cookie is not cryptographically valid for up to 30 days regardless
|
||||
// of the (possibly much shorter) configured sessionMaxAge (rank 13).
|
||||
if cs, ok := sm.store.(*sessions.CookieStore); ok {
|
||||
cs.MaxAge(int(sessionMaxAge.Seconds()))
|
||||
}
|
||||
|
||||
// Initialize global memory monitoring (singleton)
|
||||
sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
|
||||
|
||||
@@ -1566,7 +1620,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
|
||||
sd.sessionMutex.Lock()
|
||||
sd.clearAllSessionData(r, true)
|
||||
|
||||
|
||||
// Release the lock before calling Save to prevent deadlock
|
||||
sd.sessionMutex.Unlock()
|
||||
|
||||
|
||||
+28
-1
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -54,6 +55,7 @@ type Config struct {
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
ExtraAuthParams map[string]string `json:"extraAuthParams,omitempty"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
|
||||
// a stored refresh token. Once the token has been in the session longer
|
||||
@@ -761,7 +763,32 @@ func validateTemplateSecure(templateStr string) error {
|
||||
// Returns true if the URL is valid and secure (HTTPS), false otherwise.
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
if err != nil || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
if u.Scheme == "https" {
|
||||
return true
|
||||
}
|
||||
// Permit plaintext HTTP only for loopback hosts (local development,
|
||||
// in-cluster sidecar providers, tests). Loopback traffic never leaves the
|
||||
// host, so it is not exposed to network MITM; remote providers must use
|
||||
// HTTPS. Mirrors the RFC 8252 loopback allowance.
|
||||
if u.Scheme == "http" && isLoopbackHost(u.Hostname()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLoopbackHost reports whether host is "localhost" or a loopback IP literal
|
||||
// (127.0.0.0/8 or ::1).
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
|
||||
|
||||
+41
-11
@@ -106,8 +106,9 @@ func (rm *ResourceManager) GetCache(key string) interface{} {
|
||||
case "jwk-cache":
|
||||
cache = cacheManager.GetSharedJWKCache()
|
||||
default:
|
||||
// Generic cache implementation
|
||||
cache = NewGenericCache(1*time.Hour, rm.logger)
|
||||
// Generic cache implementation; bind cleanup goroutine to the manager's
|
||||
// shutdown channel so it exits when the ResourceManager shuts down.
|
||||
cache = newGenericCacheWithOwner(1*time.Hour, rm.logger, rm.shutdownChan)
|
||||
}
|
||||
|
||||
rm.caches[key] = cache
|
||||
@@ -263,6 +264,19 @@ func (rm *ResourceManager) cleanupInstance(instanceID string) {
|
||||
// This is a hook for future instance-specific cleanup needs
|
||||
}
|
||||
|
||||
// liveInstanceCount tracks the number of fully-constructed TraefikOidc plugin
|
||||
// instances alive in this process. Process-global singleton tasks (such as the
|
||||
// shared token-cleanup) must only be stopped when the LAST instance shuts down,
|
||||
// otherwise one instance's teardown would disable them for all survivors.
|
||||
var liveInstanceCount int32
|
||||
|
||||
// registerLiveInstance records a newly constructed plugin instance.
|
||||
func registerLiveInstance() { atomic.AddInt32(&liveInstanceCount, 1) }
|
||||
|
||||
// unregisterLiveInstance records a plugin instance shutting down and returns the
|
||||
// number of instances still alive afterwards.
|
||||
func unregisterLiveInstance() int32 { return atomic.AddInt32(&liveInstanceCount, -1) }
|
||||
|
||||
// Shutdown gracefully shuts down all managed resources
|
||||
func (rm *ResourceManager) Shutdown(ctx context.Context) error {
|
||||
var err error
|
||||
@@ -501,20 +515,31 @@ func (p *GoroutinePool) Shutdown(ctx context.Context) error {
|
||||
|
||||
// GenericCache provides a simple cache implementation for testing
|
||||
type GenericCache struct {
|
||||
data map[string]interface{}
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
data map[string]interface{}
|
||||
// ownerStopChan, when non-nil, signals the cleanup goroutine to exit when
|
||||
// the owning ResourceManager shuts down, so the goroutine cannot outlive it.
|
||||
ownerStopChan <-chan struct{}
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewGenericCache creates a new generic cache
|
||||
func NewGenericCache(ttl time.Duration, logger *Logger) *GenericCache {
|
||||
return newGenericCacheWithOwner(ttl, logger, nil)
|
||||
}
|
||||
|
||||
// newGenericCacheWithOwner creates a generic cache whose cleanup goroutine also
|
||||
// exits when ownerStopChan is closed (typically the ResourceManager shutdown
|
||||
// channel), guaranteeing the goroutine is stoppable on shutdown.
|
||||
func newGenericCacheWithOwner(ttl time.Duration, logger *Logger, ownerStopChan <-chan struct{}) *GenericCache {
|
||||
cache := &GenericCache{
|
||||
data: make(map[string]interface{}),
|
||||
ttl: ttl,
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
data: make(map[string]interface{}),
|
||||
ttl: ttl,
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
ownerStopChan: ownerStopChan,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
@@ -570,6 +595,11 @@ func (gc *GenericCache) cleanupRoutine() {
|
||||
gc.mu.Unlock()
|
||||
case <-gc.stopChan:
|
||||
return
|
||||
case <-gc.ownerStopChan:
|
||||
// Owning ResourceManager is shutting down; exit so the goroutine
|
||||
// does not outlive its owner. A nil channel blocks forever, so this
|
||||
// case is inert when no owner is set.
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+27
-15
@@ -296,9 +296,12 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
|
||||
// Create a TraefikOidc instance with context
|
||||
config := &Config{
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
plugin, err := NewWithContext(ctx, config, nil, "test")
|
||||
@@ -350,9 +353,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
configs := []Config{
|
||||
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
|
||||
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
|
||||
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
|
||||
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
|
||||
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
|
||||
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
|
||||
}
|
||||
|
||||
var plugins []*TraefikOidc
|
||||
@@ -432,9 +435,12 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
for i := 0; i < 3; i++ {
|
||||
ctx := context.Background()
|
||||
config := &Config{
|
||||
ProviderURL: mockServers[i].URL,
|
||||
ClientID: fmt.Sprintf("client%d", i),
|
||||
ClientSecret: fmt.Sprintf("secret%d", i),
|
||||
ProviderURL: mockServers[i].URL,
|
||||
ClientID: fmt.Sprintf("client%d", i),
|
||||
ClientSecret: fmt.Sprintf("secret%d", i),
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
plugin, err := NewWithContext(ctx, config, nil, fmt.Sprintf("test-%d", i))
|
||||
@@ -595,9 +601,12 @@ func TestBackwardCompatibility(t *testing.T) {
|
||||
t.Run("LegacyNewFunction", func(t *testing.T) {
|
||||
// Test that the original New function still works
|
||||
config := &Config{
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
handler, err := New(context.Background(), nil, config, "test")
|
||||
@@ -617,9 +626,12 @@ func TestBackwardCompatibility(t *testing.T) {
|
||||
|
||||
t.Run("ExistingAPICompatibility", func(t *testing.T) {
|
||||
config := &Config{
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
ProviderURL: "https://example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
RateLimit: 100,
|
||||
}
|
||||
|
||||
handler, _ := New(context.Background(), nil, config, "test")
|
||||
|
||||
-142
@@ -1,142 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pluginVersion is bumped manually on each release. Keep in sync with the
|
||||
// most recent git tag (see `git tag --sort=-v:refname | head -1`).
|
||||
const pluginVersion = "1.0.11"
|
||||
|
||||
const (
|
||||
telemetryProject = "traefikoidc"
|
||||
telemetryTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// telemetryEndpoint is intentionally a var rather than a const so the test
|
||||
// suite in this package can retarget it at an httptest server. Production
|
||||
// code never mutates it.
|
||||
var telemetryEndpoint = "https://oss.raczylo.com/v1/ping"
|
||||
|
||||
// telemetryOnce guarantees a single anonymous "plugin loaded" ping per
|
||||
// process lifetime. Traefik can instantiate a middleware many times per
|
||||
// process (one per route using the plugin); the sync.Once gate keeps the
|
||||
// fire-and-forget call from amplifying into many pings.
|
||||
//
|
||||
// Reset in tests via `telemetryOnce = sync.Once{}`.
|
||||
var telemetryOnce sync.Once
|
||||
|
||||
// telemetryInflight tracks any background goroutine started by sendTelemetry.
|
||||
// Tests Wait on it to drain in-flight goroutines before mutating package
|
||||
// state. Production code never calls Wait — the goroutine is fire-and-forget.
|
||||
var telemetryInflight sync.WaitGroup
|
||||
|
||||
// sendTelemetry fires one anonymous usage ping in the background. It is
|
||||
// failproof by contract:
|
||||
//
|
||||
// - never blocks the caller
|
||||
// - never panics (the goroutine recovers internally)
|
||||
// - never returns errors
|
||||
// - silently dropped on invalid input, env-driven opt-out, or network failure
|
||||
//
|
||||
// Opt-out is honored via any of:
|
||||
//
|
||||
// - DO_NOT_TRACK=1
|
||||
// - OSS_TELEMETRY_DISABLED=1
|
||||
// - TRAEFIKOIDC_DISABLE_TELEMETRY=1
|
||||
//
|
||||
// Yaegi note: this file deliberately avoids generics (atomic.Pointer[T]) and
|
||||
// range-over-int (Go 1.22) so it interprets under any reasonably recent
|
||||
// Traefik yaegi runtime.
|
||||
func sendTelemetry(version string) {
|
||||
telemetryOnce.Do(func() {
|
||||
if telemetryDisabledByEnv() {
|
||||
return
|
||||
}
|
||||
if !validTelemetryVersion(version) {
|
||||
return
|
||||
}
|
||||
telemetryInflight.Add(1)
|
||||
go func() {
|
||||
defer telemetryInflight.Done()
|
||||
defer func() { _ = recover() }()
|
||||
doTelemetryPost(version)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func telemetryDisabledByEnv() bool {
|
||||
keys := []string{
|
||||
"DO_NOT_TRACK",
|
||||
"OSS_TELEMETRY_DISABLED",
|
||||
"TRAEFIKOIDC_DISABLE_TELEMETRY",
|
||||
}
|
||||
for _, k := range keys {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(k)))
|
||||
if v == "1" || v == "true" || v == "yes" || v == "on" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// validTelemetryVersion mirrors the server-side regex ^[A-Za-z0-9.+_-]{1,32}$
|
||||
// using a byte loop. No allocation, no regexp dependency.
|
||||
//
|
||||
// Yaegi note: written as an `||` chain rather than `switch{case A,B,C:}` —
|
||||
// some yaegi releases mis-evaluate comma-separated case expressions in
|
||||
// switch-true blocks, returning false for all inputs.
|
||||
func validTelemetryVersion(v string) bool {
|
||||
if len(v) == 0 || len(v) > 32 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(v); i++ {
|
||||
c := v[i]
|
||||
ok := (c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '.' || c == '+' || c == '_' || c == '-'
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// doTelemetryPost builds the JSON body manually. The project name is a
|
||||
// constant and the version is pre-validated against an ASCII-only allowlist,
|
||||
// so direct concatenation needs no JSON escaping.
|
||||
func doTelemetryPost(version string) {
|
||||
body := make([]byte, 0, 96)
|
||||
body = append(body, `{"project":"`...)
|
||||
body = append(body, telemetryProject...)
|
||||
body = append(body, `","version":"`...)
|
||||
body = append(body, version...)
|
||||
body = append(body, `","ts":`...)
|
||||
body = strconv.AppendInt(body, time.Now().Unix(), 10)
|
||||
body = append(body, '}')
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), telemetryTimeout)
|
||||
defer cancel()
|
||||
|
||||
url := telemetryEndpoint
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: telemetryTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
@@ -1,167 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// resetTelemetryState restores package-level mutable state so tests do not
|
||||
// contaminate one another. The cleanup waits for any in-flight ping goroutine
|
||||
// to finish before restoring telemetryEndpoint — without that drain step the
|
||||
// goroutine and the cleanup would race on the var.
|
||||
func resetTelemetryState(t *testing.T) {
|
||||
t.Helper()
|
||||
telemetryOnce = sync.Once{}
|
||||
prev := telemetryEndpoint
|
||||
t.Cleanup(func() {
|
||||
telemetryInflight.Wait()
|
||||
telemetryEndpoint = prev
|
||||
telemetryOnce = sync.Once{}
|
||||
})
|
||||
}
|
||||
|
||||
func newTelemetryServer(t *testing.T, status int) (hits *int32, lastBody func() string) {
|
||||
t.Helper()
|
||||
var counter int32
|
||||
var mu sync.Mutex
|
||||
var body string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
mu.Lock()
|
||||
body = string(b)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(status)
|
||||
}))
|
||||
telemetryEndpoint = srv.URL
|
||||
t.Cleanup(srv.Close)
|
||||
return &counter, func() string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return body
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidTelemetryVersion(t *testing.T) {
|
||||
good := []string{"1.2.3", "1.4.0-beta1", "2.0", "v1.0.0", "1.0.0+meta", "dev"}
|
||||
for _, v := range good {
|
||||
if !validTelemetryVersion(v) {
|
||||
t.Errorf("validTelemetryVersion(%q) = false, want true", v)
|
||||
}
|
||||
}
|
||||
bad := []string{"", "has space", "semi;colon", strings.Repeat("1", 33)}
|
||||
for _, v := range bad {
|
||||
if validTelemetryVersion(v) {
|
||||
t.Errorf("validTelemetryVersion(%q) = true, want false", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTelemetryDisabledByEnv(t *testing.T) {
|
||||
for _, k := range []string{"DO_NOT_TRACK", "OSS_TELEMETRY_DISABLED", "TRAEFIKOIDC_DISABLE_TELEMETRY"} {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv(k, "1")
|
||||
if !telemetryDisabledByEnv() {
|
||||
t.Fatalf("%s=1 should disable", k)
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("falsy_values_do_not_disable", func(t *testing.T) {
|
||||
t.Setenv("DO_NOT_TRACK", "0")
|
||||
t.Setenv("OSS_TELEMETRY_DISABLED", "false")
|
||||
t.Setenv("TRAEFIKOIDC_DISABLE_TELEMETRY", "no")
|
||||
if telemetryDisabledByEnv() {
|
||||
t.Fatal("falsy env values should not disable")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendTelemetry_FiresOnceAcrossManyCalls(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, lastBody := newTelemetryServer(t, http.StatusNoContent)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
sendTelemetry("1.2.3")
|
||||
}
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 1 {
|
||||
t.Fatalf("expected exactly 1 hit, got %d", got)
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Project string `json:"project"`
|
||||
Version string `json:"version"`
|
||||
Ts int64 `json:"ts"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(lastBody()), &payload); err != nil {
|
||||
t.Fatalf("server received non-JSON body: %q (err: %v)", lastBody(), err)
|
||||
}
|
||||
if payload.Project != "traefikoidc" || payload.Version != "1.2.3" || payload.Ts <= 0 {
|
||||
t.Fatalf("unexpected payload: %+v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_RespectsDisableEnv(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, _ := newTelemetryServer(t, http.StatusNoContent)
|
||||
t.Setenv("DO_NOT_TRACK", "1")
|
||||
|
||||
sendTelemetry("1.2.3")
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 0 {
|
||||
t.Fatalf("DO_NOT_TRACK should suppress; got %d hits", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_DropsInvalidVersion(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, _ := newTelemetryServer(t, http.StatusNoContent)
|
||||
|
||||
sendTelemetry("has space")
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 0 {
|
||||
t.Fatalf("invalid version should suppress; got %d hits", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_DoesNotBlock(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
// Hanging server proves the caller is never blocked. The 2s context
|
||||
// timeout in doTelemetryPost ensures the goroutine eventually exits;
|
||||
// resetTelemetryState's cleanup waits for that drain before restoring
|
||||
// telemetryEndpoint so there is no race with this test's mutation.
|
||||
hung := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
time.Sleep(5 * time.Second)
|
||||
}))
|
||||
t.Cleanup(hung.Close)
|
||||
telemetryEndpoint = hung.URL
|
||||
|
||||
start := time.Now()
|
||||
sendTelemetry("1.2.3")
|
||||
if elapsed := time.Since(start); elapsed > 50*time.Millisecond {
|
||||
t.Fatalf("sendTelemetry blocked for %v, expected near-instant return", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_SurvivesServerError(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, _ := newTelemetryServer(t, http.StatusInternalServerError)
|
||||
|
||||
sendTelemetry("1.2.3")
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 1 {
|
||||
t.Fatalf("request should still reach server even on 500; got %d hits", got)
|
||||
}
|
||||
}
|
||||
+29
-14
@@ -21,13 +21,16 @@ type IntrospectionResponse struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Aud string `json:"aud,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
Jti string `json:"jti,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
Nbf int64 `json:"nbf,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
// Aud holds the introspection audience. Per RFC 7662 it may be a single
|
||||
// string or an array of strings, so it is decoded as interface{} and
|
||||
// matched with verifyAudience (which handles both shapes).
|
||||
Aud interface{} `json:"aud,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
Jti string `json:"jti,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
Nbf int64 `json:"nbf,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
// introspectToken performs OAuth 2.0 Token Introspection (RFC 7662) for an opaque token.
|
||||
@@ -120,7 +123,7 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
|
||||
|
||||
// Parse response per RFC 7662 Section 2.2
|
||||
var introspectionResp IntrospectionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&introspectionResp); err != nil {
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&introspectionResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode introspection response: %w", err)
|
||||
}
|
||||
|
||||
@@ -128,6 +131,12 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
|
||||
if t.introspectionCache != nil {
|
||||
// Cache for a short duration or until token expiry (whichever is shorter)
|
||||
cacheDuration := 5 * time.Minute
|
||||
// When introspection is REQUIRED, operators expect near-real-time
|
||||
// revocation; cap the positive-result cache so a token revoked at the
|
||||
// provider cannot keep passing for the full 5 minutes (rank 8).
|
||||
if t.requireTokenIntrospection && cacheDuration > 30*time.Second {
|
||||
cacheDuration = 30 * time.Second
|
||||
}
|
||||
if introspectionResp.Exp > 0 {
|
||||
expTime := time.Unix(introspectionResp.Exp, 0)
|
||||
untilExp := time.Until(expTime)
|
||||
@@ -197,12 +206,18 @@ func (t *TraefikOidc) validateOpaqueToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate audience if configured
|
||||
// Note: For opaque tokens, audience validation via introspection may be limited
|
||||
// depending on what the introspection endpoint returns
|
||||
if t.audience != "" && t.audience != t.clientID && resp.Aud != "" {
|
||||
if resp.Aud != t.audience {
|
||||
return fmt.Errorf("invalid audience: expected %s, got %s", t.audience, resp.Aud)
|
||||
// Validate audience if configured. When a distinct API audience is
|
||||
// configured (audience != clientID), the introspection response MUST carry
|
||||
// a matching audience. Fail closed on a missing or mismatched aud: a token
|
||||
// whose audience cannot be confirmed must not be accepted, otherwise a
|
||||
// token minted for a different audience would pass. aud may be a single
|
||||
// string or an array of strings (RFC 7662); verifyAudience handles both.
|
||||
if t.audience != "" && t.audience != t.clientID {
|
||||
if resp.Aud == nil {
|
||||
return fmt.Errorf("invalid audience: expected %s, introspection response has no audience", t.audience)
|
||||
}
|
||||
if err := verifyAudience(resp.Aud, t.audience); err != nil {
|
||||
return fmt.Errorf("invalid audience: expected %s: %w", t.audience, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+9
-6
@@ -5,6 +5,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -212,11 +214,13 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Complex token type detection with multiple provider-specific checks
|
||||
func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool {
|
||||
// Use first 32 chars of token as cache key (sufficient for uniqueness)
|
||||
cacheKey := token
|
||||
if len(token) > 32 {
|
||||
cacheKey = token[:32]
|
||||
}
|
||||
// Key on a hash of the FULL token. The first 32 characters of a JWT are
|
||||
// only the base64url-encoded header, which is identical for every token
|
||||
// sharing the same alg+kid, so distinct tokens (e.g. an ID token and an
|
||||
// access token from the same issuer) would otherwise collide on the cache
|
||||
// key and be mis-classified.
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
cacheKey := hex.EncodeToString(sum[:])
|
||||
|
||||
// Check cache first
|
||||
if t.tokenTypeCache != nil {
|
||||
@@ -858,7 +862,6 @@ func (t *TraefikOidc) isAzureProvider() bool {
|
||||
strings.Contains(issuerURL, "login.windows.net")
|
||||
}
|
||||
|
||||
|
||||
// startTokenCleanup starts background cleanup goroutines for cache maintenance.
|
||||
// It runs periodic cleanup of token cache, JWK cache, and session chunks.
|
||||
// Includes panic recovery to ensure stability.
|
||||
|
||||
+5
-535
@@ -3,7 +3,9 @@ package traefikoidc
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -885,10 +887,9 @@ func TestDetectTokenTypeCaching(t *testing.T) {
|
||||
},
|
||||
}
|
||||
token := "test-token-for-caching-with-enough-characters-for-key"
|
||||
cacheKey := token
|
||||
if len(token) > 32 {
|
||||
cacheKey = token[:32]
|
||||
}
|
||||
// The cache key is a SHA-256 hash of the full token (collision-resistant).
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
cacheKey := hex.EncodeToString(sum[:])
|
||||
|
||||
result := tr.detectTokenType(jwt, token)
|
||||
if !result {
|
||||
@@ -911,521 +912,6 @@ func TestDetectTokenTypeCaching(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TOKEN VALIDATOR TESTS
|
||||
// =============================================================================
|
||||
|
||||
func TestNewTokenValidator(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected non-nil token validator")
|
||||
}
|
||||
|
||||
if validator.logger == nil {
|
||||
t.Error("Expected logger to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTokenValidatorWithLogger(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
validator := NewTokenValidator(logger)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected non-nil token validator")
|
||||
}
|
||||
|
||||
if validator.logger != logger {
|
||||
t.Error("Expected provided logger to be used")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenEmpty(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
result := validator.ValidateToken("", false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for empty token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for empty token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "empty") {
|
||||
t.Errorf("Expected 'empty' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenRequireJWT(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
result := validator.ValidateToken("opaque_token_value_here", true)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for opaque token when JWT required")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error when JWT required but opaque token provided")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTValidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("Expected valid result, got error: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.TokenType != "JWT" {
|
||||
t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType)
|
||||
}
|
||||
|
||||
if result.Claims == nil {
|
||||
t.Error("Expected claims to be parsed")
|
||||
}
|
||||
|
||||
if result.Expiry == nil {
|
||||
t.Error("Expected expiry to be extracted")
|
||||
}
|
||||
|
||||
if result.IssuedAt == nil {
|
||||
t.Error("Expected issued at to be extracted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTExpiredToken(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for expired token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for expired token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "expired") {
|
||||
t.Errorf("Expected 'expired' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTFutureIssuedAt(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(10 * time.Minute).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for future iat")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for future iat")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "future") {
|
||||
t.Errorf("Expected 'future' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTNotBeforeClaim(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"nbf": time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for nbf in future")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for nbf in future")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "not yet valid") {
|
||||
t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTInvalidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"single part", "eyJhbGciOiJIUzI1NiJ9"},
|
||||
{"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"},
|
||||
{"four parts", "part1.part2.part3.part4"},
|
||||
{"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateToken(tt.token, true)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for malformed JWT")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for malformed JWT")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenValid(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "sk_live_abcdef123456GHIJKL789"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("Expected valid result, got error: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.TokenType != "Opaque" {
|
||||
t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenTooShort(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "short"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for short token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for short token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "too short") {
|
||||
t.Errorf("Expected 'too short' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenWithSpaces(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "this token has spaces in it"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for token with spaces")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for token with spaces")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "spaces") {
|
||||
t.Errorf("Expected 'spaces' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenControlCharacters(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "token_with\x00control_char"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for token with control characters")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for token with control characters")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "control character") {
|
||||
t.Errorf("Expected 'control character' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "aaaaaabbbbbbccccccdddd"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for low entropy token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for low entropy token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "entropy") {
|
||||
t.Errorf("Expected 'entropy' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidBase64URL(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true},
|
||||
{"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true},
|
||||
{"valid numbers", "0123456789", true},
|
||||
{"valid dash", "abc-def", true},
|
||||
{"valid underscore", "abc_def", true},
|
||||
{"valid equals", "abc=", true},
|
||||
{"invalid at sign", "abc@def", false},
|
||||
{"invalid space", "abc def", false},
|
||||
{"invalid plus", "abc+def", false},
|
||||
{"invalid slash", "abc/def", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.isValidBase64URL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTime(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
claim interface{}
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{name: "float64", claim: float64(1609459200), expected: true},
|
||||
{name: "int64", claim: int64(1609459200), expected: true},
|
||||
{name: "int", claim: int(1609459200), expected: true},
|
||||
{name: "string", claim: "not a timestamp", expected: false},
|
||||
{name: "nil", claim: nil, expected: false},
|
||||
{name: "map", claim: map[string]interface{}{}, expected: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.extractTime(tt.claim)
|
||||
|
||||
if tt.expected && result == nil {
|
||||
t.Error("Expected non-nil time")
|
||||
}
|
||||
|
||||
if !tt.expected && result != nil {
|
||||
t.Error("Expected nil time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenSize(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
maxSize int
|
||||
expectError bool
|
||||
}{
|
||||
{"within limit", "short_token", 20, false},
|
||||
{"at limit", "exactly_twenty_c", 16, false},
|
||||
{"exceeds limit", "this_token_is_too_long", 10, true},
|
||||
{"empty token", "", 10, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateTokenSize(tt.token, tt.maxSize)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error for oversized token")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "exceeds") {
|
||||
t.Errorf("Expected 'exceeds' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(1609459200),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
extracted, err := validator.ExtractClaims(token)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extracted == nil {
|
||||
t.Fatal("Expected non-nil claims")
|
||||
}
|
||||
|
||||
if extracted["sub"] != "user123" {
|
||||
t.Errorf("Expected sub 'user123', got %v", extracted["sub"])
|
||||
}
|
||||
|
||||
if extracted["email"] != "user@example.com" {
|
||||
t.Errorf("Expected email 'user@example.com', got %v", extracted["email"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsInvalidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"single part", "onlyonepart"},
|
||||
{"two parts", "two.parts"},
|
||||
{"four parts", "one.two.three.four"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := validator.ExtractClaims(tt.token)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid format")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid JWT format") {
|
||||
t.Errorf("Expected 'invalid JWT format' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensEqual(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "secret_token_12345"
|
||||
token2 := "secret_token_12345"
|
||||
|
||||
if !validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensDifferent(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "secret_token_12345"
|
||||
token2 := "secret_token_54321"
|
||||
|
||||
if validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be different")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensDifferentLength(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "short"
|
||||
token2 := "much_longer_token"
|
||||
|
||||
if validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be different (different lengths)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensEmpty(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := ""
|
||||
token2 := ""
|
||||
|
||||
if !validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected empty tokens to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenMaliciousPayloads(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"sql injection attempt", "'; DROP TABLE users; --"},
|
||||
{"xss attempt", "<script>alert('xss')</script>"},
|
||||
{"path traversal", "../../../etc/passwd"},
|
||||
{"null bytes", "token\x00with\x00nulls"},
|
||||
{"unicode exploit", "token\u0000\u0001\u0002"},
|
||||
{"extremely long", strings.Repeat("a", 100000)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateToken(tt.token, false)
|
||||
|
||||
if result.Valid {
|
||||
if result.Claims != nil {
|
||||
t.Logf("Token considered valid: %s", tt.name)
|
||||
}
|
||||
} else {
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for malicious payload")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CONSOLIDATED TOKEN TESTS
|
||||
// =============================================================================
|
||||
@@ -2098,19 +1584,3 @@ func createTokenOfSize(baseToken string, targetSize int) string {
|
||||
|
||||
return baseToken
|
||||
}
|
||||
|
||||
func createTestJWTSimple(claims map[string]interface{}) string {
|
||||
header := map[string]interface{}{
|
||||
"alg": "HS256",
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature"))
|
||||
|
||||
return headerB64 + "." + claimsB64 + "." + signature
|
||||
}
|
||||
|
||||
+15
-2
@@ -149,7 +149,15 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo
|
||||
if rs.idToken != "" {
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
return true, false, false
|
||||
// No ID token to corroborate an access token we cannot verify
|
||||
// (Azure nonce-bearing Graph access tokens carry a proprietary,
|
||||
// client-unverifiable signature). Do NOT authenticate on an
|
||||
// unverified token: refresh if a refresh token is available,
|
||||
// otherwise force re-authentication.
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +166,12 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return true, false, false
|
||||
// Opaque access token, no ID token to corroborate it, and
|
||||
// introspection was unavailable/disabled/errored (e.g.
|
||||
// circuit-breaker open). There is nothing left to verify the token
|
||||
// against, so fail closed and force re-authentication rather than
|
||||
// trusting an unverified opaque token.
|
||||
return false, false, true
|
||||
}
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/pool"
|
||||
)
|
||||
|
||||
// TokenValidator provides unified token validation functionality
|
||||
type TokenValidator struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewTokenValidator creates a new token validator
|
||||
func NewTokenValidator(logger *Logger) *TokenValidator {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
return &TokenValidator{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TokenValidationResult contains the result of token validation
|
||||
type TokenValidationResult struct {
|
||||
Error error
|
||||
Claims map[string]interface{}
|
||||
Expiry *time.Time
|
||||
IssuedAt *time.Time
|
||||
TokenType string
|
||||
Valid bool
|
||||
}
|
||||
|
||||
// ValidateToken performs comprehensive token validation
|
||||
func (v *TokenValidator) ValidateToken(token string, requireJWT bool) TokenValidationResult {
|
||||
result := TokenValidationResult{}
|
||||
|
||||
// Basic validation
|
||||
if token == "" {
|
||||
result.Error = fmt.Errorf("token is empty")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check if it's a JWT or opaque token
|
||||
dotCount := strings.Count(token, ".")
|
||||
isJWT := dotCount == 2
|
||||
|
||||
if requireJWT && !isJWT {
|
||||
result.Error = fmt.Errorf("token is not a valid JWT (found %d dots, expected 2)", dotCount)
|
||||
return result
|
||||
}
|
||||
|
||||
if isJWT {
|
||||
return v.validateJWT(token)
|
||||
} else {
|
||||
return v.validateOpaqueToken(token)
|
||||
}
|
||||
}
|
||||
|
||||
// validateJWT validates a JWT token
|
||||
func (v *TokenValidator) validateJWT(token string) TokenValidationResult {
|
||||
result := TokenValidationResult{
|
||||
TokenType: "JWT",
|
||||
}
|
||||
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
result.Error = fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate each part
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
result.Error = fmt.Errorf("JWT part %d is empty", i)
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for valid base64url characters
|
||||
if !v.isValidBase64URL(part) {
|
||||
result.Error = fmt.Errorf("JWT part %d contains invalid base64url characters", i)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Decode and parse claims
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
result.Error = fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
pm := pool.Get()
|
||||
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
|
||||
defer pm.PutJSONDecoder(decoder)
|
||||
if err := decoder.Decode(&claims); err != nil {
|
||||
result.Error = fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
return result
|
||||
}
|
||||
|
||||
result.Claims = claims
|
||||
|
||||
// Extract standard claims
|
||||
if exp, ok := claims["exp"]; ok {
|
||||
expTime := v.extractTime(exp)
|
||||
if expTime != nil {
|
||||
result.Expiry = expTime
|
||||
// Check if expired
|
||||
if time.Now().After(*expTime) {
|
||||
result.Error = fmt.Errorf("token is expired (expired at %v)", expTime.Format(time.RFC3339))
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if iat, ok := claims["iat"]; ok {
|
||||
iatTime := v.extractTime(iat)
|
||||
if iatTime != nil {
|
||||
result.IssuedAt = iatTime
|
||||
// Check if issued in future
|
||||
if iatTime.After(time.Now().Add(5 * time.Minute)) {
|
||||
result.Error = fmt.Errorf("token issued in future (iat: %v)", iatTime.Format(time.RFC3339))
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check nbf (not before)
|
||||
if nbf, ok := claims["nbf"]; ok {
|
||||
nbfTime := v.extractTime(nbf)
|
||||
if nbfTime != nil && time.Now().Before(*nbfTime) {
|
||||
result.Error = fmt.Errorf("token not yet valid (nbf: %v)", nbfTime.Format(time.RFC3339))
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.Valid = true
|
||||
return result
|
||||
}
|
||||
|
||||
// validateOpaqueToken validates an opaque token
|
||||
func (v *TokenValidator) validateOpaqueToken(token string) TokenValidationResult {
|
||||
result := TokenValidationResult{
|
||||
TokenType: "Opaque",
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(token) < 20 {
|
||||
result.Error = fmt.Errorf("opaque token too short (length: %d)", len(token))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for spaces
|
||||
if strings.Contains(token, " ") {
|
||||
result.Error = fmt.Errorf("opaque token contains spaces")
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for control characters
|
||||
for i, char := range token {
|
||||
if char < 32 || char == 127 {
|
||||
result.Error = fmt.Errorf("opaque token contains control character at position %d", i)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check entropy
|
||||
if len(token) >= 20 {
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, char := range token {
|
||||
uniqueChars[char] = true
|
||||
}
|
||||
if len(uniqueChars) < 8 {
|
||||
result.Error = fmt.Errorf("opaque token has insufficient entropy (unique chars: %d)", len(uniqueChars))
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.Valid = true
|
||||
return result
|
||||
}
|
||||
|
||||
// isValidBase64URL checks if a string contains only valid base64url characters
|
||||
func (v *TokenValidator) isValidBase64URL(s string) bool {
|
||||
for _, char := range s {
|
||||
if !((char >= 'A' && char <= 'Z') ||
|
||||
(char >= 'a' && char <= 'z') ||
|
||||
(char >= '0' && char <= '9') ||
|
||||
char == '-' || char == '_' || char == '=') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// extractTime extracts a time.Time from various claim formats
|
||||
func (v *TokenValidator) extractTime(claim interface{}) *time.Time {
|
||||
var timestamp int64
|
||||
|
||||
switch val := claim.(type) {
|
||||
case float64:
|
||||
timestamp = int64(val)
|
||||
case int64:
|
||||
timestamp = val
|
||||
case int:
|
||||
timestamp = int64(val)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
t := time.Unix(timestamp, 0)
|
||||
return &t
|
||||
}
|
||||
|
||||
// ValidateTokenSize checks if token size is within acceptable limits
|
||||
func (v *TokenValidator) ValidateTokenSize(token string, maxSize int) error {
|
||||
if len(token) > maxSize {
|
||||
return fmt.Errorf("token exceeds maximum size (size: %d, max: %d)", len(token), maxSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractClaims extracts claims from a JWT without full validation
|
||||
func (v *TokenValidator) ExtractClaims(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
pm := pool.Get()
|
||||
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
|
||||
defer pm.PutJSONDecoder(decoder)
|
||||
if err := decoder.Decode(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// CompareTokens safely compares two tokens for equality
|
||||
func (v *TokenValidator) CompareTokens(token1, token2 string) bool {
|
||||
if len(token1) != len(token2) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use constant-time comparison to prevent timing attacks
|
||||
var result byte
|
||||
for i := 0; i < len(token1); i++ {
|
||||
result |= token1[i] ^ token2[i]
|
||||
}
|
||||
return result == 0
|
||||
}
|
||||
@@ -165,6 +165,7 @@ type TraefikOidc struct {
|
||||
frontchannelLogoutPath string
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
extraAuthParams map[string]string
|
||||
refreshGracePeriod time.Duration
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
|
||||
+55
-10
@@ -396,8 +396,16 @@ func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
|
||||
return value, true
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
// Expired — fall through to the write-locked slow path below to
|
||||
// remove the entry under exclusive access.
|
||||
// Expired — return miss immediately. The periodic cleanup goroutine
|
||||
// will evict the stale entry. NEVER fall through to the write-locked
|
||||
// slow path for Token/JWK/Session caches: under Yaegi the write Lock
|
||||
// at line 403 costs 10-100ms per acquisition, and Go's RWMutex
|
||||
// writer-priority semantics block ALL new RLock callers while a Lock
|
||||
// is pending. A single expired-token event turns every concurrent
|
||||
// request from read-parallel into write-serialized — the exact
|
||||
// convoy that produced the 737-goroutine pileup at 0x400275a608.
|
||||
atomic.AddInt64(&c.misses, 1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -595,15 +603,28 @@ func (c *UniversalCache) removeItem(key string, item *CacheItem) {
|
||||
|
||||
// evictOldest evicts the oldest item from the cache (must be called with lock held)
|
||||
func (c *UniversalCache) evictOldest() {
|
||||
if elem := c.lruList.Back(); elem != nil {
|
||||
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
|
||||
if item, exists := c.items[key]; exists {
|
||||
c.removeItem(key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
|
||||
}
|
||||
elem := c.lruList.Back()
|
||||
if elem == nil {
|
||||
return
|
||||
}
|
||||
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
|
||||
if item, exists := c.items[key]; exists && item.element == elem {
|
||||
c.removeItem(key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Defensive forward-progress guard: the back node is dangling — its key is
|
||||
// absent from c.items, or c.items[key] points at a newer node (a stale
|
||||
// duplicate). Drop the node directly so an eviction loop
|
||||
// (`for ... && c.lruList.Len() > 0`) is guaranteed to terminate and can
|
||||
// never spin holding c.mu.Lock(). With the updateLocalCache replace-in-place
|
||||
// fix this branch should be unreachable, but it makes the spin impossible.
|
||||
c.lruList.Remove(elem)
|
||||
if c.currentSize > 0 {
|
||||
c.currentSize--
|
||||
}
|
||||
}
|
||||
|
||||
@@ -936,6 +957,30 @@ func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl tim
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
// Replace an existing entry in place: update the item and move its single
|
||||
// list node to the front. Without this, a repeat populate of the same key
|
||||
// (the per-request Get->backend-hit path) would PushFront a duplicate node
|
||||
// and overwrite c.items[key], orphaning the previous node. Orphans inflate
|
||||
// currentMemory/currentSize and, once eviction deletes the key, leave a
|
||||
// Back() node whose key is absent from c.items — so evictOldest() spins
|
||||
// while holding c.mu.Lock(): the 100%-CPU write-lock convoy seen in pprof.
|
||||
// setLocal dedups the same way; evictOldest also guards any dangling node.
|
||||
if existing, exists := c.items[key]; exists {
|
||||
c.currentMemory -= existing.Size
|
||||
c.lruList.Remove(existing.element)
|
||||
|
||||
existing.Value = value
|
||||
existing.Size = size
|
||||
existing.ExpiresAt = now.Add(ttl)
|
||||
existing.LastAccessed = now
|
||||
existing.AccessCount++
|
||||
|
||||
existing.element = c.lruList.PushFront(key)
|
||||
c.currentMemory += size
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
item := &CacheItem{
|
||||
Key: key,
|
||||
Value: value,
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// newOrphanTestCache builds a Token-type cache with background cleanup disabled
|
||||
// so the test fully controls lruList/items state.
|
||||
func newOrphanTestCache(maxMem int64) *UniversalCache {
|
||||
return NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
DefaultTTL: time.Hour,
|
||||
MaxSize: 1_000_000, // large: keep the size-branch out of the way
|
||||
MaxMemoryBytes: maxMem,
|
||||
EnableMemoryLimit: maxMem > 0,
|
||||
SkipAutoCleanup: true,
|
||||
EnableAutoCleanup: false,
|
||||
})
|
||||
}
|
||||
|
||||
// TestUpdateLocalCache_NoOrphanElements is the direct red test: repeatedly
|
||||
// populating the SAME key via updateLocalCache (the per-request Get->backend-hit
|
||||
// path) must NOT leave dangling lruList elements. Today updateLocalCache blindly
|
||||
// PushFronts + overwrites c.items[key] without removing the prior element, so the
|
||||
// list grows one orphan per call while items stays at 1 entry.
|
||||
func TestUpdateLocalCache_NoOrphanElements(t *testing.T) {
|
||||
c := newOrphanTestCache(0) // memory limit off: isolate the orphan, no eviction
|
||||
const key = "same-key"
|
||||
|
||||
for range 5 {
|
||||
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
|
||||
t.Fatalf("updateLocalCache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
listLen := c.lruList.Len()
|
||||
itemCount := len(c.items)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if itemCount != 1 {
|
||||
t.Fatalf("items: got %d want 1", itemCount)
|
||||
}
|
||||
if listLen != 1 {
|
||||
t.Fatalf("ORPHAN BUG: lruList.Len()=%d but items=%d (one list element per key expected)", listLen, itemCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateLocalCache_EvictionTerminates is the convoy reproducer: once orphans
|
||||
// for a key exist and the memory-eviction loop runs, evictOldest() deletes the
|
||||
// key from items on the first eviction, after which every remaining orphan at
|
||||
// Back() has a key absent from items -> evictOldest() no-ops while lruList.Len()>0
|
||||
// stays true -> infinite loop while holding c.mu.Lock(). That is the 100%-CPU
|
||||
// holder + write-lock convoy observed in pprof.
|
||||
func TestUpdateLocalCache_EvictionTerminates(t *testing.T) {
|
||||
c := newOrphanTestCache(0) // start with memory limit OFF to accumulate orphans
|
||||
const key = "same-key"
|
||||
|
||||
// Build 3 same-key list elements (3 orphans, items={key}).
|
||||
for range 3 {
|
||||
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
|
||||
t.Fatalf("seed updateLocalCache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Arm the trap: tiny memory limit so the next call enters the eviction loop.
|
||||
c.mu.Lock()
|
||||
c.config.MaxMemoryBytes = 1
|
||||
c.mu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = c.updateLocalCache(key, "v", time.Hour) // triggers the eviction loop
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// fix present: loop made forward progress and returned
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("INFINITE LOOP: eviction loop did not terminate within 2s (orphan whose key was deleted is never removed from lruList)")
|
||||
}
|
||||
}
|
||||
+98
-1
@@ -19,7 +19,7 @@ import (
|
||||
// - true if the URL should be excluded from authentication, false otherwise.
|
||||
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
for excludedURL := range t.excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
if pathExcluded(currentRequest, excludedURL) {
|
||||
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
@@ -27,6 +27,31 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// pathExcluded reports whether requestPath is covered by an excluded prefix at a
|
||||
// natural boundary: an exact match, a sub-path ("/public" → "/public/x"), or a
|
||||
// file extension ("/favicon" → "/favicon.ico"). It deliberately does NOT match
|
||||
// an unrelated sibling such as "/publicsecret", so a configured exclusion can no
|
||||
// longer be widened into an authentication bypass on a different resource.
|
||||
func pathExcluded(requestPath, excluded string) bool {
|
||||
excluded = strings.TrimRight(excluded, "/")
|
||||
if excluded == "" {
|
||||
// A "/" (root) exclusion only matches the root path, not everything.
|
||||
return requestPath == "" || requestPath == "/"
|
||||
}
|
||||
if requestPath == excluded {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(requestPath, excluded) {
|
||||
return false
|
||||
}
|
||||
switch requestPath[len(excluded)] {
|
||||
case '/', '.':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// buildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
@@ -146,6 +171,21 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
// Apply operator-configured extra authorization parameters (e.g.
|
||||
// screen_hint, login_hint, ui_locales, prompt). These are added last but
|
||||
// can never override parameters the plugin itself manages (client_id,
|
||||
// state, nonce, redirect_uri, code_challenge, scope, response_type, ...):
|
||||
// a key already present in params is left untouched, so this cannot
|
||||
// weaken security-critical parameters.
|
||||
for key, value := range t.extraAuthParams {
|
||||
if params.Get(key) == "" {
|
||||
params.Set(key, value)
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Added extra auth param %s", key)
|
||||
} else {
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Skipped extra auth param %s (already set by plugin)", key)
|
||||
}
|
||||
}
|
||||
|
||||
// Read authURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
authURL := t.authURL
|
||||
@@ -273,6 +313,63 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateDiscoveredEndpoint validates an endpoint URL obtained from the
|
||||
// provider's OIDC/OAuth2 discovery document before the plugin issues any
|
||||
// outbound request to it. A discovery document is attacker-influenced if the
|
||||
// provider is malicious or its TLS is broken, so an unvalidated endpoint is an
|
||||
// SSRF vector (e.g. jwks_uri or introspection_endpoint pointed at the cloud
|
||||
// metadata service 169.254.169.254 or an internal host).
|
||||
//
|
||||
// Empty endpoints are allowed (they are optional). Link-local (which covers the
|
||||
// 169.254.0.0/16 metadata range), multicast and unspecified addresses are
|
||||
// always rejected. Private addresses are rejected unless allowPrivateIPAddresses
|
||||
// is set. Loopback is rejected unless allowLoopback is true — which the caller
|
||||
// sets only when the operator-configured providerURL is itself loopback (local
|
||||
// development, in-cluster sidecars, tests), so production deployments pointed at
|
||||
// a real provider still block loopback SSRF.
|
||||
func (t *TraefikOidc) validateDiscoveredEndpoint(urlStr string, allowLoopback bool) error {
|
||||
if urlStr == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
if u.Scheme != "https" && u.Scheme != "http" {
|
||||
return fmt.Errorf("disallowed URL scheme: %q", u.Scheme)
|
||||
}
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("missing host in URL")
|
||||
}
|
||||
if ip := net.ParseIP(u.Hostname()); ip != nil {
|
||||
switch {
|
||||
case ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified():
|
||||
return fmt.Errorf("endpoint host is a blocked address: %s", ip)
|
||||
case ip.IsLoopback() && !allowLoopback:
|
||||
return fmt.Errorf("endpoint host is a loopback address: %s", ip)
|
||||
case ip.IsPrivate() && !t.allowPrivateIPAddresses:
|
||||
return fmt.Errorf("endpoint host is a private address: %s", ip)
|
||||
}
|
||||
}
|
||||
if strings.Contains(u.Path, "..") {
|
||||
return fmt.Errorf("path traversal detected in URL path")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sameHost reports whether two URLs share the same host:port (case-insensitive).
|
||||
// Used to pin the credential-bearing introspection endpoint to the operator-
|
||||
// configured provider so a poisoned discovery document cannot redirect the
|
||||
// client secret to an attacker-controlled host.
|
||||
func sameHost(a, b string) bool {
|
||||
ua, erra := url.Parse(a)
|
||||
ub, errb := url.Parse(b)
|
||||
if erra != nil || errb != nil || ua.Host == "" || ub.Host == "" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(ua.Host, ub.Host)
|
||||
}
|
||||
|
||||
// validateHost validates a hostname or IP address for security.
|
||||
// It prevents access to localhost, private networks, and known metadata endpoints.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
|
||||
@@ -554,3 +554,54 @@ func TestForceHTTPSIntegration(t *testing.T) {
|
||||
"should use https from X-Forwarded-Proto when forceHTTPS is false")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildAuthURLExtraAuthParams verifies operator-configured extra
|
||||
// authorization parameters are appended to the authorization URL, and that
|
||||
// they can never override parameters the plugin itself manages.
|
||||
func TestBuildAuthURLExtraAuthParams(t *testing.T) {
|
||||
t.Run("extra params are added (e.g. screen_hint=signup)", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.extraAuthParams = map[string]string{
|
||||
"screen_hint": "signup",
|
||||
"ui_locales": "en",
|
||||
}
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback", "state123", "nonce456", "",
|
||||
)
|
||||
|
||||
assert.Contains(t, authURL, "screen_hint=signup")
|
||||
assert.Contains(t, authURL, "ui_locales=en")
|
||||
})
|
||||
|
||||
t.Run("nil/empty extraAuthParams is a no-op", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
// extraAuthParams left nil
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback", "state123", "nonce456", "",
|
||||
)
|
||||
|
||||
assert.Contains(t, authURL, "client_id=test-client")
|
||||
assert.NotContains(t, authURL, "screen_hint")
|
||||
})
|
||||
|
||||
t.Run("extra params CANNOT override plugin-managed params", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.extraAuthParams = map[string]string{
|
||||
"client_id": "ATTACKER",
|
||||
"state": "ATTACKER",
|
||||
"redirect_uri": "https://evil.example.com",
|
||||
"response_type": "token",
|
||||
}
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback", "state123", "nonce456", "",
|
||||
)
|
||||
|
||||
// Plugin-managed values must win; injected values must be absent.
|
||||
assert.Contains(t, authURL, "client_id=test-client")
|
||||
assert.NotContains(t, authURL, "ATTACKER")
|
||||
assert.NotContains(t, authURL, "evil.example.com")
|
||||
assert.Contains(t, authURL, "response_type=code")
|
||||
})
|
||||
}
|
||||
|
||||
+8
-3
@@ -135,7 +135,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := parts[1]
|
||||
domain := strings.ToLower(parts[1])
|
||||
_, domainAllowed := t.allowedUserDomains[domain]
|
||||
|
||||
if domainAllowed {
|
||||
@@ -236,8 +236,13 @@ func (t *TraefikOidc) Close() error {
|
||||
// Get resource manager for cleanup
|
||||
rm := GetResourceManager()
|
||||
|
||||
// Stop singleton tasks related to this instance
|
||||
_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup
|
||||
// singleton-token-cleanup is a process-global task shared by every plugin
|
||||
// instance. Only stop it when the LAST instance is shutting down;
|
||||
// otherwise one instance's teardown (e.g. a single config reload) would
|
||||
// kill chunked-session/token cleanup for all surviving instances (rank 12).
|
||||
if unregisterLiveInstance() <= 0 {
|
||||
_ = rm.StopBackgroundTask("singleton-token-cleanup") // best effort, last instance only
|
||||
}
|
||||
// Stop metadata refresh task using same hash-based name as startMetadataRefresh
|
||||
if t.providerURL != "" {
|
||||
hash := sha256.Sum256([]byte(t.providerURL))
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
.docs
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
timeout: 2m
|
||||
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- bodyclose
|
||||
- errcheck
|
||||
- errorlint
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- revive
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- unused
|
||||
settings:
|
||||
gocyclo:
|
||||
min-complexity: 12
|
||||
revive:
|
||||
rules:
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
- name: redefines-builtin-id
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- goimports
|
||||
+42
@@ -0,0 +1,42 @@
|
||||
# Configuration for lukaszraczylo/semver-generator.
|
||||
# Reference: https://github.com/lukaszraczylo/semver-generator
|
||||
#
|
||||
# Word matching is fuzzy + case-insensitive. The keywords below mirror the
|
||||
# Conventional Commits prefixes used in this repo's git history. Same pattern
|
||||
# as github.com/lukaszraczylo/go-telegram/.semver.yaml.
|
||||
|
||||
version: 1
|
||||
|
||||
# Respect existing v* tags as the version baseline. semver-generator finds
|
||||
# the highest existing tag and bumps from there. With no tags yet, the first
|
||||
# release computes from the empty base.
|
||||
force:
|
||||
existing: true
|
||||
|
||||
# Skip merge commits and machine-generated traffic that would otherwise
|
||||
# spuriously bump the version.
|
||||
blacklist:
|
||||
- "Merge branch"
|
||||
- "Merge pull request"
|
||||
- "Merge remote-tracking branch"
|
||||
- "go mod tidy"
|
||||
|
||||
wording:
|
||||
patch:
|
||||
- "fix"
|
||||
- "chore"
|
||||
- "docs"
|
||||
- "test"
|
||||
- "style"
|
||||
- "refactor"
|
||||
- "build"
|
||||
- "ci"
|
||||
- "perf"
|
||||
minor:
|
||||
- "feat"
|
||||
major:
|
||||
# Match only the canonical Conventional Commits trailer. The bare word
|
||||
# "breaking" is too greedy under semver-generator's fuzzy match — it
|
||||
# triggers on substrings inside a commit body and wrongly produces a
|
||||
# major bump.
|
||||
- "BREAKING CHANGE"
|
||||
+122
@@ -0,0 +1,122 @@
|
||||
# oss-telemetry
|
||||
|
||||
A tiny Go client that fires one anonymous "this binary started" ping at a
|
||||
central ingest endpoint. Designed to be embedded in your own open-source
|
||||
projects so you can see approximate adoption and version spread without
|
||||
collecting anything that could identify a user.
|
||||
|
||||
This is the **client library only**. The ingest endpoint, server-side
|
||||
deduplication, rate limiting, and metrics are out of scope here.
|
||||
|
||||
## What it sends
|
||||
|
||||
A single HTTP `POST` per call to `Send`:
|
||||
|
||||
```json
|
||||
{
|
||||
"project": "my-tool",
|
||||
"version": "1.2.3",
|
||||
"ts": 1747782200
|
||||
}
|
||||
```
|
||||
|
||||
No identifiers, no IP, no machine info, no user data. The server dedupes
|
||||
incoming requests; the client just fires and forgets.
|
||||
|
||||
## Failproof by design
|
||||
|
||||
- Never blocks the caller — work runs in a goroutine.
|
||||
- Never panics — the goroutine recovers internally.
|
||||
- Never returns errors — bad input and network failures are silently dropped.
|
||||
- Never retries, never persists state, never reads back.
|
||||
- 2-second hard timeout on every request.
|
||||
- Zero third-party dependencies (Go stdlib only).
|
||||
|
||||
The endpoint is hardcoded and not overridable from consuming code, by design.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
go get github.com/lukaszraczylo/oss-telemetry
|
||||
```
|
||||
|
||||
Requires Go 1.22+.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
telemetry "github.com/lukaszraczylo/oss-telemetry"
|
||||
)
|
||||
|
||||
const version = "1.2.3"
|
||||
|
||||
func main() {
|
||||
telemetry.Send("my-tool", version)
|
||||
|
||||
// ... your program runs ...
|
||||
|
||||
// Only needed for short-lived CLIs that may exit before the goroutine
|
||||
// finishes its POST. Long-running services do not need this.
|
||||
telemetry.Wait(2 * time.Second)
|
||||
}
|
||||
```
|
||||
|
||||
Call `Send` once at boot. Calling it more often just sends more pings; the
|
||||
server deduplicates.
|
||||
|
||||
## Disabling telemetry
|
||||
|
||||
If you ship a binary that imports this library, link your users to this
|
||||
section (`https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry`)
|
||||
so they can find the opt-out paths.
|
||||
|
||||
Any one of these turns it off:
|
||||
|
||||
| Mechanism | How |
|
||||
| ---------------------------------------- | ---------------------------------------------------------------- |
|
||||
| Universal opt-out | `DO_NOT_TRACK=1` |
|
||||
| Library-wide opt-out | `OSS_TELEMETRY_DISABLED=1` |
|
||||
| Project-specific opt-out | `<UPPER_PROJECT>_DISABLE_TELEMETRY=1` |
|
||||
| Programmatic (e.g. behind a `--no-telemetry` flag) | `telemetry.Disable()` before the first `Send` |
|
||||
|
||||
Project-specific env var derivation: uppercase the project name and replace
|
||||
`-` with `_`. For `my-tool` the variable is `MY_TOOL_DISABLE_TELEMETRY`.
|
||||
|
||||
Truthy values: `1`, `true`, `yes`, `on` (case-insensitive). Anything else is
|
||||
treated as "not set".
|
||||
|
||||
## Validation rules (silently dropped if violated)
|
||||
|
||||
- `project`: matches `^[a-z0-9-]+$`, length 1–64.
|
||||
- `version`: matches `^[A-Za-z0-9.+_-]+$`, length 1–32.
|
||||
|
||||
Bad input is a no-op — the library never logs, never errors, never crashes.
|
||||
|
||||
## API
|
||||
|
||||
```go
|
||||
// Fire a single ping in the background. Returns immediately.
|
||||
func Send(project, version string)
|
||||
|
||||
// Suppress all subsequent Send calls in this process. Idempotent.
|
||||
func Disable()
|
||||
|
||||
// Block until in-flight pings complete or timeout elapses, whichever first.
|
||||
// Useful for short-lived CLI processes.
|
||||
func Wait(timeout time.Duration)
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
go test -race ./...
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Pick one before publishing. None bundled.
|
||||
+367
@@ -0,0 +1,367 @@
|
||||
// Package telemetry sends anonymous usage pings for open-source Go projects.
|
||||
//
|
||||
// Wire format (POST application/json):
|
||||
//
|
||||
// {"project":"<name>","version":"<ver>","ts":<unix-seconds>}
|
||||
//
|
||||
// Design contract (failproof):
|
||||
// - never blocks the caller (work happens in a goroutine)
|
||||
// - never panics (background goroutine recovers internally)
|
||||
// - never returns errors (silently no-ops on bad input or network failure)
|
||||
// - never retries, never deduplicates, never persists state — the client
|
||||
// fires a single ping and forgets; the server is responsible for
|
||||
// deduplication, abuse protection, and aggregation
|
||||
//
|
||||
// Typical usage at program startup:
|
||||
//
|
||||
// telemetry.Send("my-tool", "1.2.3")
|
||||
//
|
||||
// For short-lived CLI processes that may exit before the goroutine finishes:
|
||||
//
|
||||
// telemetry.Send("my-tool", "1.2.3")
|
||||
// defer telemetry.Wait(2 * time.Second)
|
||||
//
|
||||
// Disablement (any one of these suppresses pings):
|
||||
// - environment variable DO_NOT_TRACK=1
|
||||
// - environment variable OSS_TELEMETRY_DISABLED=1
|
||||
// - environment variable <UPPER_PROJECT>_DISABLE_TELEMETRY=1
|
||||
// (project name uppercased, dashes replaced with underscores)
|
||||
// - calling telemetry.Disable() at runtime
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEndpoint = "https://oss.raczylo.com/v1/ping"
|
||||
httpTimeout = 2 * time.Second
|
||||
maxProjectLen = 64
|
||||
maxVersionLen = 32
|
||||
)
|
||||
|
||||
// Yaegi note: this package is consumed by the traefikoidc Traefik plugin, which
|
||||
// Traefik interprets with Yaegi (it vendors and interprets dependency source).
|
||||
// It therefore avoids generic stdlib types (atomic.Pointer[T], atomic.Bool) and
|
||||
// range-over-int (Go 1.22), which some Traefik/Yaegi runtimes cannot interpret.
|
||||
// Endpoint mutation uses a mutex-guarded string; the disabled flag uses the
|
||||
// function-based sync/atomic int32 API (atomic.LoadInt32/StoreInt32).
|
||||
var (
|
||||
// endpointURL holds the ingest URL. Production code never mutates it; the
|
||||
// setter exists only so the test suite can retarget it at httptest servers
|
||||
// while goroutines started by Send are still in flight.
|
||||
endpointMu sync.RWMutex
|
||||
endpointURL = defaultEndpoint
|
||||
|
||||
disabled int32 // 0 = enabled, 1 = disabled; accessed via sync/atomic only
|
||||
inflight sync.WaitGroup
|
||||
|
||||
client = &http.Client{Timeout: httpTimeout}
|
||||
)
|
||||
|
||||
func currentEndpoint() string {
|
||||
endpointMu.RLock()
|
||||
defer endpointMu.RUnlock()
|
||||
return endpointURL
|
||||
}
|
||||
|
||||
func setEndpointURL(u string) {
|
||||
endpointMu.Lock()
|
||||
endpointURL = u
|
||||
endpointMu.Unlock()
|
||||
}
|
||||
|
||||
// Send fires a single anonymous telemetry ping in the background and returns
|
||||
// immediately. It never blocks, never panics, and never reports errors.
|
||||
// Invalid inputs, disabled state, and network failures are silently dropped.
|
||||
//
|
||||
// Version strings are validated against a SemVer-ish shape that mirrors the
|
||||
// receiver. An optional leading "v" or "V" is accepted and stripped before
|
||||
// transmission so that callers can pass either "v1.2.3" or "1.2.3"; the
|
||||
// wire form is always the unprefixed canonical version.
|
||||
//
|
||||
// Call once at program startup. Calling repeatedly will send repeated pings;
|
||||
// the server is responsible for deduplication.
|
||||
func Send(project, version string) {
|
||||
if atomic.LoadInt32(&disabled) != 0 {
|
||||
return
|
||||
}
|
||||
if isDisabledByEnv(project) {
|
||||
return
|
||||
}
|
||||
if !validProject(project) || !validVersion(version) {
|
||||
return
|
||||
}
|
||||
canonical := normalizeVersion(version)
|
||||
|
||||
inflight.Add(1)
|
||||
go func() {
|
||||
defer inflight.Done()
|
||||
defer func() { _ = recover() }()
|
||||
dispatch(project, canonical)
|
||||
}()
|
||||
}
|
||||
|
||||
// SendForModule is the recommended call form for Go libraries: it resolves
|
||||
// the version automatically from Go's build info for the given module path
|
||||
// so consumers do not need to maintain a hand-bumped version constant in
|
||||
// source. Behaviour and contract are otherwise identical to [Send].
|
||||
//
|
||||
// Resolution order:
|
||||
//
|
||||
// 1. debug.ReadBuildInfo Deps entry for modulePath (authoritative when the
|
||||
// library is consumed via go.mod);
|
||||
// 2. debug.ReadBuildInfo Main when the library is itself the main module
|
||||
// (e.g. running its own tests or examples);
|
||||
// 3. fallback parameter, used only when build info is unavailable or
|
||||
// unhelpful (replace directives, detached `go run`, ldflag override).
|
||||
//
|
||||
// Any leading "v" reported by build info is stripped to match the canonical
|
||||
// wire form. Empty / "(devel)" build versions are skipped in favour of the
|
||||
// next resolution source. Typical usage:
|
||||
//
|
||||
// telemetry.SendForModule("my-tool", "github.com/me/my-tool", "0.0.0-dev")
|
||||
func SendForModule(project, modulePath, fallback string) {
|
||||
Send(project, ResolveModuleVersion(modulePath, fallback))
|
||||
}
|
||||
|
||||
// ResolveModuleVersion implements the version resolution used by
|
||||
// SendForModule. Exposed for callers that need to format the resolved
|
||||
// version (e.g. logging) without firing a ping.
|
||||
func ResolveModuleVersion(modulePath, fallback string) string {
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
for _, d := range info.Deps {
|
||||
if d != nil && d.Path == modulePath && isUsableBuildVersion(d.Version) {
|
||||
return strings.TrimPrefix(d.Version, "v")
|
||||
}
|
||||
}
|
||||
if info.Main.Path == modulePath && isUsableBuildVersion(info.Main.Version) {
|
||||
return strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func isUsableBuildVersion(v string) bool {
|
||||
return v != "" && v != "(devel)"
|
||||
}
|
||||
|
||||
// Disable suppresses all subsequent Send calls in this process.
|
||||
// Idempotent and safe to call from any goroutine.
|
||||
func Disable() {
|
||||
atomic.StoreInt32(&disabled, 1)
|
||||
}
|
||||
|
||||
// Wait blocks until all in-flight pings have completed, or until timeout
|
||||
// elapses — whichever comes first. Useful for short-lived CLI processes
|
||||
// that may otherwise exit before the background goroutine finishes its POST.
|
||||
//
|
||||
// A non-positive timeout returns immediately.
|
||||
func Wait(timeout time.Duration) {
|
||||
if timeout <= 0 {
|
||||
return
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
inflight.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
func dispatch(project, version string) {
|
||||
body := buildPayload(project, version, time.Now().Unix())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), httpTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, currentEndpoint(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// buildPayload writes the JSON body without encoding/json. The validators
|
||||
// restrict project and version to characters that never require JSON
|
||||
// escaping, so direct concatenation is safe.
|
||||
func buildPayload(project, version string, ts int64) []byte {
|
||||
// Wrapper text plus 20 chars for a signed int64.
|
||||
const overhead = len(`{"project":"","version":"","ts":}`) + 20
|
||||
buf := make([]byte, 0, len(project)+len(version)+overhead)
|
||||
buf = append(buf, `{"project":"`...)
|
||||
buf = append(buf, project...)
|
||||
buf = append(buf, `","version":"`...)
|
||||
buf = append(buf, version...)
|
||||
buf = append(buf, `","ts":`...)
|
||||
buf = strconv.AppendInt(buf, ts, 10)
|
||||
buf = append(buf, '}')
|
||||
return buf
|
||||
}
|
||||
|
||||
func validProject(p string) bool {
|
||||
n := len(p)
|
||||
if n == 0 || n > maxProjectLen {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
c := p[i]
|
||||
switch {
|
||||
case c >= 'a' && c <= 'z',
|
||||
c >= '0' && c <= '9',
|
||||
c == '-':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// validVersion accepts SemVer-ish version strings with an optional leading
|
||||
// "v"/"V" prefix. Acceptable shape (after stripping the leading v):
|
||||
//
|
||||
// MAJOR[.MINOR[.PATCH]] ("-"prerelease)? ("+"build)?
|
||||
//
|
||||
// where MAJOR/MINOR/PATCH are ASCII digit sequences and the prerelease/build
|
||||
// payloads are non-empty runs of [0-9A-Za-z.-]. This intentionally mirrors
|
||||
// the receiver's version regex so junk like "dev" or "git-2026-05-22" never
|
||||
// leaves the client (where it would only be rejected with HTTP 400 anyway).
|
||||
func validVersion(v string) bool {
|
||||
n := len(v)
|
||||
if n == 0 || n > maxVersionLen {
|
||||
return false
|
||||
}
|
||||
if v[0] == 'v' || v[0] == 'V' {
|
||||
v = v[1:]
|
||||
}
|
||||
if len(v) == 0 {
|
||||
return false
|
||||
}
|
||||
return checkSemverShape(v)
|
||||
}
|
||||
|
||||
// normalizeVersion strips an optional leading "v"/"V" so the on-the-wire
|
||||
// version matches the form stored server-side by the version refresher cron
|
||||
// (which also strips the leading v from release tags). Callers may pass
|
||||
// either "v1.2.3" or "1.2.3" — only the unprefixed form is transmitted.
|
||||
func normalizeVersion(v string) string {
|
||||
if len(v) > 0 && (v[0] == 'v' || v[0] == 'V') {
|
||||
return v[1:]
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func checkSemverShape(s string) bool {
|
||||
i := 0
|
||||
if !readDigitRun(s, &i) {
|
||||
return false
|
||||
}
|
||||
for groups := 0; groups < 2 && i < len(s) && s[i] == '.'; groups++ {
|
||||
i++
|
||||
if !readDigitRun(s, &i) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if i < len(s) && s[i] == '-' {
|
||||
i++
|
||||
if !readIdentRun(s, &i, '+') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if i < len(s) && s[i] == '+' {
|
||||
i++
|
||||
if !readIdentRun(s, &i, 0) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return i == len(s)
|
||||
}
|
||||
|
||||
func readDigitRun(s string, i *int) bool {
|
||||
start := *i
|
||||
for *i < len(s) && s[*i] >= '0' && s[*i] <= '9' {
|
||||
*i++
|
||||
}
|
||||
return *i > start
|
||||
}
|
||||
|
||||
// readIdentRun consumes [0-9A-Za-z.-] until end-of-string or until `stop`
|
||||
// is hit (stop=0 disables the early-stop check). Returns false if no
|
||||
// characters were consumed (i.e. empty payload).
|
||||
func readIdentRun(s string, i *int, stop byte) bool {
|
||||
start := *i
|
||||
for *i < len(s) {
|
||||
c := s[*i]
|
||||
if stop != 0 && c == stop {
|
||||
break
|
||||
}
|
||||
valid := (c >= '0' && c <= '9') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
c == '.' || c == '-'
|
||||
if !valid {
|
||||
return false
|
||||
}
|
||||
*i++
|
||||
}
|
||||
return *i > start
|
||||
}
|
||||
|
||||
func isDisabledByEnv(project string) bool {
|
||||
if truthy(os.Getenv("DO_NOT_TRACK")) {
|
||||
return true
|
||||
}
|
||||
if truthy(os.Getenv("OSS_TELEMETRY_DISABLED")) {
|
||||
return true
|
||||
}
|
||||
if project == "" {
|
||||
return false
|
||||
}
|
||||
key := projectEnvKey(project)
|
||||
return truthy(os.Getenv(key))
|
||||
}
|
||||
|
||||
// projectEnvKey returns "<UPPER_PROJECT>_DISABLE_TELEMETRY" using a single
|
||||
// allocation rather than chained strings.ToUpper(strings.ReplaceAll(...)).
|
||||
func projectEnvKey(project string) string {
|
||||
const suffix = "_DISABLE_TELEMETRY"
|
||||
buf := make([]byte, 0, len(project)+len(suffix))
|
||||
for i := 0; i < len(project); i++ {
|
||||
c := project[i]
|
||||
switch {
|
||||
case c == '-':
|
||||
c = '_'
|
||||
case c >= 'a' && c <= 'z':
|
||||
c -= 'a' - 'A'
|
||||
}
|
||||
buf = append(buf, c)
|
||||
}
|
||||
buf = append(buf, suffix...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func truthy(s string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
Vendored
+3
@@ -24,6 +24,9 @@ github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/sessions
|
||||
# github.com/lukaszraczylo/oss-telemetry v0.2.3
|
||||
## explicit; go 1.22
|
||||
github.com/lukaszraczylo/oss-telemetry
|
||||
# github.com/pmezard/go-difflib v1.0.0
|
||||
## explicit
|
||||
github.com/pmezard/go-difflib/difflib
|
||||
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
package traefikoidc
|
||||
|
||||
// devPluginVersion is the placeholder carried by source-tree / local / test
|
||||
// builds. Telemetry is suppressed while the plugin still reports this sentinel,
|
||||
// so only stamped release builds emit a "plugin loaded" ping.
|
||||
const devPluginVersion = "0.0.0-dev"
|
||||
|
||||
// traefikoidcPluginVersion is the released version of this plugin. It is stamped
|
||||
// at release time by ./workflow-prepare.sh (invoked by the shared go-release
|
||||
// workflow before GoReleaser builds and tags), which rewrites the string below
|
||||
// to the computed semver.
|
||||
//
|
||||
// Traefik runs this plugin under Yaegi, where the version cannot be resolved
|
||||
// from build info at runtime (debug.ReadBuildInfo sees Traefik's build graph,
|
||||
// not the interpreted plugin). This build-stamped constant is therefore the
|
||||
// single source of truth for the version reported by anonymous usage telemetry.
|
||||
const traefikoidcPluginVersion = "0.0.0-dev"
|
||||
Executable
+67
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# workflow-prepare.sh — stamp the release version into version.go at build time.
|
||||
#
|
||||
# The shared go-release workflow (lukaszraczylo/shared-actions go-release.yaml)
|
||||
# runs this script, if present, from the repository root BEFORE GoReleaser
|
||||
# builds and tags. Traefik runs this plugin under Yaegi, where the version
|
||||
# cannot be resolved from build info at runtime, so the released semver must be
|
||||
# baked into source here.
|
||||
#
|
||||
# Version source — first non-empty wins:
|
||||
# $VERSION $VERSION_TAG $SEMVER $NEW_VERSION $RELEASE_VERSION
|
||||
# A leading "v"/"V" is stripped.
|
||||
#
|
||||
# NOTE: go-release.yaml @main does not yet pass the computed version into this
|
||||
# step's environment. Add it to the "Run workflow prepare script" step, e.g.:
|
||||
# env:
|
||||
# VERSION: ${{ needs.version.outputs.version }} # bare, no leading v
|
||||
#
|
||||
# The shared workflow runs this script in its test, version AND release jobs,
|
||||
# but only the release job has a computed version. So a missing version is a
|
||||
# no-op (leave the dev sentinel) — NOT a hard failure, otherwise the test/version
|
||||
# jobs would break. A malformed version that IS provided is a hard error. Wire
|
||||
# the env only on the release job's prepare step (see header note above).
|
||||
set -euo pipefail
|
||||
|
||||
FILE="version.go"
|
||||
CONST="traefikoidcPluginVersion"
|
||||
|
||||
VER="${VERSION:-${VERSION_TAG:-${SEMVER:-${NEW_VERSION:-${RELEASE_VERSION:-}}}}}"
|
||||
VER="${VER#v}"
|
||||
VER="${VER#V}"
|
||||
|
||||
if [ -z "$VER" ]; then
|
||||
if [ "${GITHUB_ACTIONS:-}" = "true" ]; then
|
||||
echo "workflow-prepare: WARNING no version provided; leaving ${FILE} at the dev placeholder. If this is the release build, set 'env: VERSION: \${{ needs.version.outputs.version }}' on the release job's prepare step — otherwise the release ships 0.0.0-dev and emits no telemetry." >&2
|
||||
else
|
||||
echo "workflow-prepare: no version provided; leaving dev placeholder in ${FILE} (local build)"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Accept MAJOR[.MINOR[.PATCH]] with optional -prerelease / +build (semver-ish,
|
||||
# matching the oss-telemetry receiver's validator).
|
||||
if ! printf '%s' "$VER" | grep -Eq '^[0-9]+(\.[0-9]+){0,2}(-[0-9A-Za-z.-]+)?(\+[0-9A-Za-z.-]+)?$'; then
|
||||
echo "workflow-prepare: ERROR version '${VER}' is not semver-shaped" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$FILE" ]; then
|
||||
echo "workflow-prepare: ERROR ${FILE} not found (run from repository root)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Rewrite only the value of ${CONST}, anchored on the constant name so the
|
||||
# sibling devPluginVersion sentinel is left untouched.
|
||||
tmp="$(mktemp)"
|
||||
sed -E "s/(${CONST}[[:space:]]*=[[:space:]]*\")[^\"]*(\")/\1${VER}\2/" "$FILE" > "$tmp"
|
||||
mv "$tmp" "$FILE"
|
||||
|
||||
if ! grep -Eq "${CONST}[[:space:]]*=[[:space:]]*\"${VER}\"" "$FILE"; then
|
||||
echo "workflow-prepare: ERROR failed to stamp version into ${FILE}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
command -v gofmt >/dev/null 2>&1 && gofmt -w "$FILE"
|
||||
echo "workflow-prepare: stamped ${CONST} = \"${VER}\" in ${FILE}"
|
||||
Reference in New Issue
Block a user