Compare commits

..

1 Commits

Author SHA1 Message Date
Hermes Agent 227de89d33 feat: add cookiePath config to scope session cookies to subpath
Fixes #122.
2026-05-27 21:43:20 +01:00
56 changed files with 2809 additions and 2251 deletions
+1 -1
View File
@@ -18,6 +18,6 @@ jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.25.x"
go-version: "1.24.11"
coverage-threshold: 70
secrets: inherit
+1 -1
View File
@@ -19,5 +19,5 @@ jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.25.x"
go-version: "1.24.11"
secrets: inherit
-61
View File
@@ -1,61 +0,0 @@
# traefikoidc — Makefile
# Run `make help` for available targets.
GO ?= go
GOPATH := $(shell $(GO) env GOPATH)
# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate
# tests the real interpreter, not a newer one that may support more. Traefik
# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is.
YAEGI_VERSION ?= v0.16.1
TEST_TIMEOUT ?= 480s
.DEFAULT_GOAL := help
.PHONY: help
help: ## Show this help
@grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}'
.PHONY: build
build: ## Compile all packages (native toolchain)
$(GO) build ./...
.PHONY: fmt
fmt: ## Format sources with gofmt
gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/')
.PHONY: vet
vet: ## Run go vet
$(GO) vet ./...
.PHONY: lint
lint: ## Run golangci-lint if available
@command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping"
.PHONY: staticcheck
staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.)
@command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; }
@GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./...
.PHONY: test
test: ## Run the test suite
$(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT)
.PHONY: vendor
vendor: ## Refresh and vendor dependencies
$(GO) mod tidy && $(GO) mod vendor
# yaegi-validate interprets the plugin under the yaegi interpreter the same way
# Traefik loads it. Native `go build`/`go test` use the standard compiler and do
# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection
# edge cases). This target does. Importing the package forces yaegi to interpret
# every file in it plus its vendored deps; CreateConfig + New exercise the
# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you
# need exact parity.
.PHONY: yaegi-validate
yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter
@command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); }
@echo ">> interpreting plugin under yaegi (as Traefik does)"
@DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go
.PHONY: check
check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation
+2 -15
View File
@@ -112,7 +112,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
| `postLogoutRedirectURI` | `/` | Where to send users after logout. |
| `scopes` | appended to `openid profile email` | Extra OAuth scopes. Set `overrideScopes: true` to replace defaults. |
| `extraAuthParams` | none | Map of extra query parameters appended to the authorization request (e.g. `screen_hint: signup`, `login_hint`, `ui_locales`, `prompt`). Plugin-managed params (`client_id`, `state`, `nonce`, `redirect_uri`, `code_challenge`, `scope`, `response_type`, …) cannot be overridden. |
| `excludedURLs` | none | Paths that bypass auth, matched at a path-segment or file-extension boundary (e.g. `/public` matches `/public`, `/public/sub` and `/public.json`, but **not** `/publicsecret`). |
| `excludedURLs` | none | Prefix-matched paths that bypass auth. |
| `allowedUserDomains` | none | Restrict to email domains. |
| `allowedUsers` | none | Restrict to specific addresses (or claim values when `userIdentifierClaim != email`). |
| `allowedRolesAndGroups` | none | Require any of these roles/groups from ID-token claims. |
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
| `enablePKCE` | `false` | PKCE on the auth code flow. |
| `cookieDomain` | auto | Set explicitly for multi-subdomain setups (`.example.com`). |
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
| `cookiePath` | `/` | Restrict cookies to a path prefix. Set to the middleware's path (e.g. `/app`) to prevent the browser from sending OIDC cookies to unprotected paths, avoiding 431 "Request Header Or Cookie Too Large" errors on mixed-use domains. |
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
@@ -147,18 +148,6 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
## Production gotchas
### Upgrading from an earlier release
- **Sessions are re-issued once.** Session cookies are now AES-256 encrypted
(previously signed only) and their cryptographic lifetime tracks
`sessionMaxAge` (previously a fixed 30 days). Existing cookies become invalid
on upgrade, so users re-authenticate one time.
- **Invalid configuration now fails closed at startup** instead of being
silently accepted: a `sessionEncryptionKey` shorter than 32 bytes, a
`rateLimit` below 10, a missing `callbackURL`, or a non-HTTPS remote
`providerURL` are rejected. Plaintext HTTP is permitted only for loopback
hosts (local development).
### TLS termination at a load balancer
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
@@ -178,8 +167,6 @@ detected" when the same token hits different replicas. Two options:
For IdP-initiated logout (back/front-channel) in multi-replica setups, Redis is
**required** so a logout on one instance invalidates sessions on the others.
Front-channel logout requests must include a matching `iss` query parameter;
requests that omit it are rejected with `400`.
### Multiple middleware instances on the same host
+1 -9
View File
@@ -182,11 +182,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
}
codeVerifier := session.GetCodeVerifier()
if t.enablePKCE && codeVerifier == "" {
t.logger.Error("PKCE is enabled but code verifier is missing from session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest)
return
}
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
@@ -268,10 +263,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
// Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored
// from the original request target before using it as the post-login
// redirect target. normalizeLogoutPath forces a host-relative path.
redirectPath = normalizeLogoutPath(incomingPath)
redirectPath = incomingPath
}
session.SetIncomingPath("")
+10 -101
View File
@@ -149,94 +149,6 @@ func parseBearerJOSEHeader(token string) *bearerError {
return nil
}
// headerClaimRuneReason reports why a rune is unsafe to inject into a request
// header value, or "" if the rune is acceptable. Shared core of the bearer-path
// identifier sanitizer and the cookie-path header claim sanitizer: rejects
// control chars (CRLF/header injection), Unicode bidi-override runes (RTL
// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group
// name would inject extra entries into a comma-joined header).
func headerClaimRuneReason(r rune) string {
if reason := headerInjectionRuneReason(r); reason != "" {
return reason
}
// The , ; = delimiters are only unsafe for values placed into delimited or
// list contexts (a comma-joined header, or an identifier downstreams may
// split). They are valid in arbitrary single header values, so this stricter
// check is used for the cookie-path identifier and the group/role list, NOT
// for free-form templated header output (see headerValueReason).
if r == ',' || r == ';' || r == '=' {
return "delimiter character"
}
return ""
}
// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header
// value, or "" if acceptable. Rejects control characters (CR/LF header
// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs).
// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in
// free-form header values (e.g. an opaque "Authorization: Bearer <token>").
func headerInjectionRuneReason(r rune) string {
if unicode.IsControl(r) {
return "control character"
}
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
return "bidi-override character"
}
return ""
}
// headerValueReason reports why value is unsafe to forward as a free-form HTTP
// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0
// disables the check) and values containing control or bidi-override runes, but
// permits , ; = (valid in header values). Empty is allowed. The reason string
// never includes the value, so it is safe to log.
func headerValueReason(value string, maxLen int) string {
if maxLen > 0 && len(value) > maxLen {
return "exceeds max length"
}
for _, r := range value {
if reason := headerInjectionRuneReason(r); reason != "" {
return reason
}
}
return ""
}
// headerClaimValueReason reports why value is unsafe to inject into a
// downstream request header, or "" if it is acceptable. It rejects empty
// values, values exceeding maxLen (maxLen<=0 disables the length check), and
// values containing any rune rejected by headerClaimRuneReason. The reason
// string is safe to log (it never includes the value itself).
func headerClaimValueReason(value string, maxLen int) string {
if value == "" {
return "empty value"
}
if maxLen > 0 && len(value) > maxLen {
return "exceeds max length"
}
for _, r := range value {
if reason := headerClaimRuneReason(r); reason != "" {
return reason
}
}
return ""
}
// sanitizeHeaderClaimValue validates a claim-derived value before it is
// injected into a downstream request header. It trims surrounding whitespace
// and fails closed (ok=false) on empty values, values exceeding maxLen
// (maxLen<=0 disables the length check), or values containing any rune rejected
// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the
// bearer path — does not otherwise sanitize the principal identifier or the
// group/role strings joined into X-User-Groups / X-User-Roles.
func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) {
value := strings.TrimSpace(raw)
if headerClaimValueReason(value, maxLen) != "" {
return "", false
}
return value, true
}
// sanitizeBearerIdentifier validates and trims a principal identifier before
// it is injected into request headers. Layered defense: net/http will reject
// CRLF on the wire too, but rejecting early gives clearer error logs and
@@ -251,8 +163,15 @@ func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
}
for _, r := range identifier {
if reason := headerClaimRuneReason(r); reason != "" {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason)
if unicode.IsControl(r) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains control character")
}
// Unicode bidi-override range (RTL spoofing of admin UI / SIEM).
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains bidi-override character")
}
if r == ',' || r == ';' || r == '=' {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains delimiter character")
}
}
return identifier, nil
@@ -423,17 +342,7 @@ func (b *bearerFailureTracker) recordSuccess(ip string) {
}
b.mu.Lock()
defer b.mu.Unlock()
e, ok := b.entries[ip]
if !ok {
return
}
// Preserve an active penalty so a single success cannot wipe an in-effect
// lockout; only reset the counter when no penalty is active or it has expired.
now := time.Now()
if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) {
e.count = 0
e.firstFailureAt = now
}
delete(b.entries, ip)
}
// clientIPForBearer returns the source IP used to key the failure tracker.
+27 -45
View File
@@ -67,31 +67,31 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
t.Helper()
sm := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("error"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestStarted: 1,
next: next,
logger: NewLogger("error"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://issuer.example.com",
audience: "https://api.example.com",
clientID: "https://api.example.com",
tokenCache: NewTokenCache(),
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
allowedRolesAndGroups: map[string]struct{}{},
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
ctx: context.Background(),
enableBearerAuth: true,
stripAuthorizationHeader: true,
bearerEmitWWWAuthenticate: true,
bearerOverridesCookie: false,
bearerIdentifierClaim: "sub",
maxIdentifierLength: 256,
maxTokenAge: 24 * time.Hour,
bearerFailureThreshold: 20,
bearerFailureWindow: 60 * time.Second,
bearerFailurePenalty: 60 * time.Second,
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
issuerURL: "https://issuer.example.com",
audience: "https://api.example.com",
clientID: "https://api.example.com",
tokenCache: NewTokenCache(),
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
allowedRolesAndGroups: map[string]struct{}{},
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
ctx: context.Background(),
enableBearerAuth: true,
stripAuthorizationHeader: true,
bearerEmitWWWAuthenticate: true,
bearerOverridesCookie: false,
bearerIdentifierClaim: "sub",
maxIdentifierLength: 256,
maxTokenAge: 24 * time.Hour,
bearerFailureThreshold: 20,
bearerFailureWindow: 60 * time.Second,
bearerFailurePenalty: 60 * time.Second,
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
}
oidc.extractClaimsFunc = extractClaims
close(oidc.initComplete)
@@ -303,33 +303,15 @@ func TestBearerFailureTracker(t *testing.T) {
if b, retry := tr.blocked(ip); !b || retry <= 0 {
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
}
// A success while a penalty is active must NOT wipe the in-effect lockout
// (otherwise a single success could clear an attacker's penalty).
// Success clears the counter.
tr.recordSuccess(ip)
if b, _ := tr.blocked(ip); !b {
t.Fatalf("expected still blocked after success while penalty active")
if b, _ := tr.blocked(ip); b {
t.Fatalf("expected unblocked after success")
}
// Other IPs are unaffected.
if b, _ := tr.blocked("10.0.0.2"); b {
t.Fatalf("unrelated IP should not be blocked")
}
// With an expired penalty, a success resets the counter so a subsequent
// sub-threshold failure does not immediately re-block.
tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond)
const ip2 = "10.0.0.3"
for i := 0; i < 3; i++ {
tr2.recordFailure(ip2)
}
time.Sleep(5 * time.Millisecond) // let the short penalty expire
if b, _ := tr2.blocked(ip2); b {
t.Fatalf("expected unblocked after penalty expiry")
}
tr2.recordSuccess(ip2) // resets count since penalty has passed
tr2.recordFailure(ip2) // single failure, well below threshold
if b, _ := tr2.blocked(ip2); b {
t.Fatalf("expected unblocked: counter should have reset after success")
}
}
// =============================================================================
+2 -23
View File
@@ -16,9 +16,8 @@ type CacheManager struct {
}
var (
globalCacheManagerInstance *CacheManager
cacheManagerInitOnce sync.Once
cacheManagerActiveFingerprint string
globalCacheManagerInstance *CacheManager
cacheManagerInitOnce sync.Once
)
// GetGlobalCacheManager returns a singleton CacheManager instance.
@@ -30,9 +29,7 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
fp := redisFingerprint(config)
cacheManagerInitOnce.Do(func() {
cacheManagerActiveFingerprint = fp
var redisConfig *RedisConfig
var logger *Logger
@@ -58,27 +55,9 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
// Warn loudly if a later instance asks for a DIFFERENT explicit Redis
// backend than the one that won initialization: the cache manager is a
// process-global singleton shared across plugin instances (yaegi), so this
// instance's divergent configuration is silently ignored, which would
// otherwise collapse cache/state isolation between routes (rank 9).
if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint {
NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp)
}
return globalCacheManagerInstance
}
// redisFingerprint returns a stable identifier for an explicitly-enabled Redis
// backend (address + key prefix), or "" when Redis is not explicitly enabled.
// Used to detect divergent cache configurations across plugin instances.
func redisFingerprint(config *Config) string {
if config == nil || config.Redis == nil || !config.Redis.Enabled {
return ""
}
return config.Redis.Address + "|" + config.Redis.KeyPrefix
}
// GetSharedTokenBlacklist returns the shared token blacklist cache
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
-46
View File
@@ -1,46 +0,0 @@
//go:build ignore
// Command yaegicheck verifies that the traefikoidc plugin can be imported and
// instantiated by the yaegi interpreter — the same way Traefik loads a plugin.
//
// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi
// to interpret every source file in the package (and its vendored
// dependencies), so any construct yaegi cannot handle (unsupported stdlib
// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load
// time. CreateConfig + New additionally exercise the instantiation path
// (session manager, cookie codec, caches, key derivation) under the interpreter.
package main
import (
"context"
"fmt"
"net/http"
"os"
oidc "github.com/lukaszraczylo/traefikoidc"
)
func main() {
cfg := oidc.CreateConfig()
cfg.ProviderURL = "https://accounts.google.com"
cfg.ClientID = "yaegi-check-client"
cfg.ClientSecret = "yaegi-check-secret"
cfg.CallbackURL = "/oauth2/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef"
cfg.RateLimit = 100
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h, err := oidc.New(context.Background(), next, cfg, "yaegi-check")
if err != nil {
fmt.Println("FAIL: New returned an error under yaegi:", err)
os.Exit(1)
}
if h == nil {
fmt.Println("FAIL: New returned a nil handler under yaegi")
os.Exit(1)
}
if closer, ok := h.(interface{ Close() error }); ok {
_ = closer.Close()
}
fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi")
}
+76
View File
@@ -278,6 +278,82 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) {
}
}
// =============================================================================
// SECURITY MONITORING TESTS
// =============================================================================
func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) {
logger := NewLogger("info")
config := SecurityMonitorConfig{
MaxFailuresPerIP: 5,
FailureWindowMinutes: 15,
BlockDurationMinutes: 30,
RapidFailureThreshold: 3,
CleanupIntervalMinutes: 60,
RetentionHours: 24,
EnablePatternDetection: true,
EnableDetailedLogging: false,
LogSuspiciousOnly: false,
}
sm := NewSecurityMonitor(config, logger)
if sm == nil {
t.Fatal("Expected non-nil SecurityMonitor")
}
// Start cleanup routine first (lowercase method)
sm.startCleanupRoutine()
// Give it a moment to start
time.Sleep(50 * time.Millisecond)
// Stop cleanup routine (public method)
sm.StopCleanupRoutine()
// Stop again should be safe
sm.StopCleanupRoutine()
}
func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) {
logger := NewLogger("info")
config := SecurityMonitorConfig{
MaxFailuresPerIP: 5,
FailureWindowMinutes: 15,
BlockDurationMinutes: 30,
RapidFailureThreshold: 3,
CleanupIntervalMinutes: 60,
RetentionHours: 24,
}
sm := NewSecurityMonitor(config, logger)
// Create handler
handler := &LoggingSecurityEventHandler{logger: logger}
// Register handler using AddEventHandler
sm.AddEventHandler(handler)
// Record a failure to trigger events
sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil)
}
func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) {
logger := NewLogger("debug")
handler := &LoggingSecurityEventHandler{logger: logger}
// Severity is a string in this implementation
events := []SecurityEvent{
{Type: "test", Severity: "low", Message: "low severity"},
{Type: "test", Severity: "medium", Message: "medium severity"},
{Type: "test", Severity: "high", Message: "high severity"},
{Type: "test", Severity: "critical", Message: "critical severity"},
}
for _, event := range events {
handler.HandleSecurityEvent(event)
}
}
// =============================================================================
// SESSION MANAGER TESTS
// =============================================================================
+1 -1
View File
@@ -178,7 +178,7 @@ clientSecret: your-client-secret
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
+199
View File
@@ -370,6 +370,21 @@ func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, res
return r.saveCredentials(resp)
}
// deleteCredentialsFromStore removes credentials from the configured storage backend
// Falls back to legacy file-based deletion if no store is configured
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
// Use store if available
if r.store != nil {
return r.store.Delete(ctx, r.providerURL)
}
// Fallback to legacy file-based deletion
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// saveCredentials persists client credentials to a file (legacy method)
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
@@ -408,3 +423,187 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse,
return &resp, nil
}
// UpdateClientRegistration updates an existing client registration using RFC 7592
// This requires the registration_client_uri and registration_access_token from the original registration
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to update")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Build update request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build update request: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create update request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("update request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read update response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse update response: %w", err)
}
// Update cache
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
return &regResp, nil
}
// ReadClientRegistration reads the current client registration using RFC 7592
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to read")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create read request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("read request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse read response: %w", err)
}
return &regResp, nil
}
// DeleteClientRegistration deletes the client registration using RFC 7592
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return fmt.Errorf("no existing registration to delete")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
if err != nil {
return fmt.Errorf("failed to create delete request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return fmt.Errorf("delete request failed: %w", err)
}
defer resp.Body.Close()
// Handle error responses (204 No Content is success)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
}
// Clear cache
r.mu.Lock()
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials from storage if persistence is enabled
if r.config.PersistCredentials {
if err := r.deleteCredentialsFromStore(ctx); err != nil {
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
}
}
r.logger.Info("Successfully deleted client registration")
return nil
}
+252
View File
@@ -735,6 +735,258 @@ func TestDCRConfigDefaults(t *testing.T) {
}
}
// TestUpdateClientRegistration tests the RFC 7592 client update functionality
func TestUpdateClientRegistration(t *testing.T) {
updateCalled := false
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPut {
updateCalled = true
// Verify authorization header
if r.Header.Get("Authorization") == "" {
t.Error("Missing Authorization header for update")
}
resp := ClientRegistrationResponse{
ClientID: "updated-client-id",
ClientSecret: "updated-client-secret",
RegistrationAccessToken: "new-access-token",
RegistrationClientURI: r.URL.String(),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
dcrConfig := &DynamicClientRegistrationConfig{
Enabled: true,
ClientMetadata: &ClientRegistrationMetadata{
RedirectURIs: []string{"https://example.com/callback"},
},
}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Perform update
ctx := context.Background()
resp, err := registrar.UpdateClientRegistration(ctx)
if err != nil {
t.Fatalf("Update failed: %v", err)
}
if !updateCalled {
t.Error("Update endpoint was not called")
}
if resp.ClientID != "updated-client-id" {
t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID)
}
}
// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality
func TestDeleteClientRegistration(t *testing.T) {
deleteCalled := false
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodDelete {
deleteCalled = true
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
tempDir := t.TempDir()
credentialsFile := filepath.Join(tempDir, "credentials.json")
// Create a credentials file to test deletion
os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600)
dcrConfig := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
CredentialsFile: credentialsFile,
}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "test-client-id",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Perform delete
ctx := context.Background()
err := registrar.DeleteClientRegistration(ctx)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if !deleteCalled {
t.Error("Delete endpoint was not called")
}
// Verify cache is cleared
if registrar.GetCachedResponse() != nil {
t.Error("Cached response should be cleared after deletion")
}
// Verify credentials file is deleted
if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) {
t.Error("Credentials file should be deleted")
}
}
// TestReadClientRegistration tests the RFC 7592 client read functionality
func TestReadClientRegistration(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
resp := ClientRegistrationResponse{
ClientID: "read-client-id",
ClientSecret: "read-client-secret",
RedirectURIs: []string{"https://example.com/callback"},
ResponseTypes: []string{"code"},
GrantTypes: []string{"authorization_code"},
ApplicationType: "web",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "original-client-id",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Read registration
ctx := context.Background()
resp, err := registrar.ReadClientRegistration(ctx)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if resp.ClientID != "read-client-id" {
t.Errorf("Read ClientID mismatch: got %s", resp.ClientID)
}
}
// TestOperationsWithoutCachedResponse tests error handling when no cached response exists
func TestOperationsWithoutCachedResponse(t *testing.T) {
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
&http.Client{},
NewLogger("DEBUG"),
dcrConfig,
"https://example.com",
)
ctx := context.Background()
// Test Update without cached response
_, err := registrar.UpdateClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Update should fail without cached response: %v", err)
}
// Test Read without cached response
_, err = registrar.ReadClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Read should fail without cached response: %v", err)
}
// Test Delete without cached response
err = registrar.DeleteClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Delete should fail without cached response: %v", err)
}
}
// TestOperationsWithoutManagementCredentials tests error handling without management URIs
func TestOperationsWithoutManagementCredentials(t *testing.T) {
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
&http.Client{},
NewLogger("DEBUG"),
dcrConfig,
"https://example.com",
)
// Set up cached response WITHOUT management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "test-client-id",
// Missing RegistrationAccessToken and RegistrationClientURI
}
registrar.mu.Unlock()
ctx := context.Background()
// Test Update without management credentials
_, err := registrar.UpdateClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Update should fail without management credentials: %v", err)
}
// Test Read without management credentials
_, err = registrar.ReadClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Read should fail without management credentials: %v", err)
}
// Test Delete without management credentials
err = registrar.DeleteClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Delete should fail without management credentials: %v", err)
}
}
// stringContains is a helper function to check if a string contains a substring
func stringContains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr))
+5 -6
View File
@@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
return true
}
errStr := strings.ToLower(err.Error())
errStr := err.Error()
for _, retryableErr := range re.config.RetryableErrors {
if contains(errStr, strings.ToLower(retryableErr)) {
if contains(errStr, retryableErr) {
return true
}
}
@@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
if netErr.Timeout() {
return true
}
errStr := strings.ToLower(netErr.Error())
errStr := netErr.Error()
temporaryPatterns := []string{
"connection refused",
"connection reset",
@@ -859,9 +859,8 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
// isServiceDegraded checks if a service is currently degraded
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
// Uses a write lock because the recovery-timeout branch deletes from the map.
gd.mutex.Lock()
defer gd.mutex.Unlock()
gd.mutex.RLock()
defer gd.mutex.RUnlock()
degradedTime, exists := gd.degradedServices[serviceName]
if !exists {
-1
View File
@@ -5,7 +5,6 @@ go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/gorilla/sessions v1.3.0
github.com/lukaszraczylo/oss-telemetry v0.2.3
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.14.0
-2
View File
@@ -16,8 +16,6 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/lukaszraczylo/oss-telemetry v0.2.3 h1:xoDtBqeZGmXj7IteiE1M5WMuzeoqag58qEleI0Cf2Ms=
github.com/lukaszraczylo/oss-telemetry v0.2.3/go.mod h1:+Cn78qZo8rc3T9eZt0v3oICYRdd75wORtSidc8lNjDQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
+1 -10
View File
@@ -392,19 +392,10 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
// localRedirect is used when there is no provider end-session endpoint and
// the plugin redirects the browser itself. It must never be an absolute URL
// derived from the request host (X-Forwarded-Host is client-controllable and
// would be an open redirect); use a host-relative path, or the operator's
// own configured absolute URL, instead.
localRedirect := "/"
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
localRedirect = normalizeLogoutPath(postLogoutRedirectURI)
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
} else {
localRedirect = postLogoutRedirectURI
}
// Read endSessionURL with RLock
@@ -423,7 +414,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
http.Redirect(rw, req, localRedirect, http.StatusFound)
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint.
+6 -30
View File
@@ -26,10 +26,6 @@ type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
// tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify)
// this transport was built with, so the at-limit fallback only reuses a
// transport whose TLS configuration matches the caller's.
tlsKey string
}
var (
@@ -57,26 +53,19 @@ func GetGlobalTransportPool() *SharedTransportPool {
// GetOrCreateTransport gets or creates a shared transport with the given config
func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport {
// SECURITY FIX: Check client limit before creating new transport.
// SECURITY FIX: Check client limit before creating new transport
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
// At the client limit: only reuse a transport that was built for the
// SAME config (same TLS trust store). refCount is mutated under the
// write lock to avoid a data race, and a transport created for a
// different configuration is never handed back — doing so could apply
// the wrong (possibly verification-disabled) TLS settings to a request.
want := tlsConfigKey(config)
p.mu.Lock()
defer p.mu.Unlock()
// Return existing transport if limit reached
p.mu.RLock()
defer p.mu.RUnlock()
for _, shared := range p.transports {
if shared != nil && shared.transport != nil && shared.tlsKey == want {
if shared != nil && shared.transport != nil {
shared.refCount++
shared.lastUsed = time.Now()
return shared.transport
}
}
// No TLS-compatible transport available; return nil so the caller falls
// back to a default, certificate-verifying transport rather than one
// with a different (possibly verification-disabled) trust store.
// If no transport available, return nil (caller should handle)
return nil
}
@@ -136,7 +125,6 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
transport: transport,
refCount: 1,
lastUsed: time.Now(),
tlsKey: tlsConfigKey(config),
}
return transport
@@ -236,18 +224,6 @@ func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
)
}
// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool
// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are
// safe to serve from the same transport even if other (non-TLS) parameters such
// as connection limits differ; configs with different TLS settings are not.
func tlsConfigKey(config HTTPClientConfig) string {
skip := "0"
if config.InsecureSkipVerify {
skip = "1"
}
return fmt.Sprintf("%p|%s", config.RootCAs, skip)
}
// Cleanup closes all transports and stops the cleanup goroutine
func (p *SharedTransportPool) Cleanup() {
p.mu.Lock()
+4 -12
View File
@@ -842,18 +842,10 @@ func TestWorkerPool_TaskPanic(t *testing.T) {
t.Error("Timeout waiting for tasks")
}
// tasksFailed is incremented in the worker's deferred recover(), which runs
// AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can
// therefore return before the failure is recorded — reading the counter
// immediately is a race that flakes on slow/contended CI runners. Poll until
// the failure lands (or time out).
deadline := time.Now().Add(2 * time.Second)
for pool.GetMetrics()["tasksFailed"].(int64) < 1 {
if time.Now().After(deadline) {
t.Error("Expected at least one failed task")
break
}
time.Sleep(5 * time.Millisecond)
// Pool should still be functional
metrics := pool.GetMetrics()
if metrics["tasksFailed"].(int64) < 1 {
t.Error("Expected at least one failed task")
}
}
+1 -23
View File
@@ -155,34 +155,12 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string {
// It checks X-Forwarded-Host header first (for proxy scenarios),
// then falls back to req.Host.
func DetermineHost(req *http.Request) string {
if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
}
return req.Host
}
// sanitizeForwardedHost returns a single, well-formed host from a (possibly
// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes
// only the first value and rejects whitespace and control characters, so a
// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison
// the redirect URLs built from the result.
func sanitizeForwardedHost(v string) string {
if v == "" {
return ""
}
if i := strings.IndexByte(v, ','); i >= 0 {
v = v[:i]
}
v = strings.TrimSpace(v)
if v == "" {
return ""
}
if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 {
return ""
}
return v
}
// BuildFullURL constructs a URL from scheme, host, and path components.
// It handles absolute URLs (returning them as-is) and ensures paths have leading slashes.
func BuildFullURL(scheme, host, path string) string {
+2 -18
View File
@@ -200,22 +200,6 @@ func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
if k.Kid == "" {
continue
}
// Skip keys that are not intended for signature verification.
if k.Use != "" && k.Use != "sig" {
continue
}
if len(k.KeyOps) > 0 {
hasVerify := false
for _, op := range k.KeyOps {
if op == "verify" {
hasVerify = true
break
}
}
if !hasVerify {
continue
}
}
var pub crypto.PublicKey
var err error
switch k.Kty {
@@ -258,11 +242,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading JWKS response: %w", err)
}
+2 -5
View File
@@ -134,11 +134,8 @@ func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
// Require a matching issuer. An empty iss must be rejected too: accepting a
// missing issuer would let an unauthenticated attacker force-logout any
// session whose sid is known by simply omitting iss.
if iss == "" || iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer)
if iss != "" && iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return
}
+25 -32
View File
@@ -125,14 +125,10 @@ func TestFrontchannelLogoutBasic(t *testing.T) {
expectedStatus: http.StatusOK,
},
{
// Front-channel logout MUST carry a matching issuer. A request
// omitting iss is rejected so an unauthenticated attacker cannot
// force-logout a session whose sid is known by simply leaving iss
// out (audit rank 30).
name: "Missing issuer is rejected",
name: "Valid front-channel logout without issuer",
method: http.MethodGet,
queryParams: map[string]string{"sid": "session456"},
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusOK,
},
}
@@ -411,17 +407,17 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
})
oidc := &TraefikOidc{
next: nextHandler,
logger: NewLogger("debug"),
enableBackchannelLogout: true,
backchannelLogoutPath: "/backchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestStarted: 1,
next: nextHandler,
logger: NewLogger("debug"),
enableBackchannelLogout: true,
backchannelLogoutPath: "/backchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
logoutURLPath: "/logout",
logoutURLPath: "/logout",
}
close(oidc.initComplete)
@@ -453,23 +449,22 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
})
oidc := &TraefikOidc{
next: nextHandler,
logger: NewLogger("debug"),
enableFrontchannelLogout: true,
frontchannelLogoutPath: "/frontchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestStarted: 1,
next: nextHandler,
logger: NewLogger("debug"),
enableFrontchannelLogout: true,
frontchannelLogoutPath: "/frontchannel-logout",
sessionInvalidationCache: mockCache,
clientID: "test-client",
issuerURL: "https://provider.example.com",
initComplete: make(chan struct{}),
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
logoutURLPath: "/logout",
logoutURLPath: "/logout",
}
close(oidc.initComplete)
// Request to front-channel logout path with valid sid + matching issuer
// should succeed. The issuer is now required (audit rank 30), so supply it.
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session&iss=https://provider.example.com", nil)
// Request to front-channel logout path with valid sid should succeed
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
@@ -1437,9 +1432,7 @@ func TestFrontchannelLogoutCacheControl(t *testing.T) {
issuerURL: "https://provider.example.com",
}
// Issuer is now required (audit rank 30); supply a matching one so the
// successful-logout cache headers can be asserted.
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123&iss=https://provider.example.com", nil)
req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil)
rw := httptest.NewRecorder()
oidc.handleFrontchannelLogout(rw, req)
+19 -79
View File
@@ -9,7 +9,6 @@ import (
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
@@ -17,7 +16,6 @@ import (
"text/template"
"time"
telemetry "github.com/lukaszraczylo/oss-telemetry"
"golang.org/x/time/rate"
)
@@ -25,11 +23,6 @@ const (
ConstSessionTimeout = 86400
)
// telemetryStartupOnce keeps the anonymous "plugin loaded" ping to one per
// process. Traefik calls New once per route that uses the plugin; oss-telemetry
// does not deduplicate client-side (the server does), so the gate stays here.
var telemetryStartupOnce sync.Once
// isTestMode detects if the code is running in a test environment.
func isTestMode() bool {
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
@@ -96,13 +89,7 @@ var defaultExcludedURLs = map[string]struct{}{
// - The configured TraefikOidc handler ready to process requests.
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
telemetryStartupOnce.Do(func() {
// Only stamped release builds phone home; dev/local/test builds keep the
// devPluginVersion sentinel (see version.go) and stay silent.
if traefikoidcPluginVersion != devPluginVersion {
telemetry.Send("traefikoidc", traefikoidcPluginVersion)
}
})
sendTelemetry(pluginVersion)
return NewWithContext(ctx, config, next, name)
}
@@ -113,18 +100,18 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
config = CreateConfig()
}
logger := NewLogger(config.LogLevel)
if config.SessionEncryptionKey == "" {
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
// Fail closed on invalid configuration. Validate() enforces the security
// constraints (required fields, HTTPS-only URLs, key length, excludedURLs
// safety, rate-limit floor, audience format, ...) that were previously
// unenforced because this constructor never called it. Crucially it rejects
// an empty or too-short SessionEncryptionKey instead of silently
// substituting a public hardcoded key, which would let an attacker forge
// any session. Traefik's yaegi plugin analyzer supplies a valid key via
// .traefik.yml testData, so it passes; only misconfigured deployments fail.
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
logger := NewLogger(config.LogLevel)
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
}
// Setup HTTP client
caPool, err := config.loadCACertPool()
@@ -236,7 +223,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
httpClient: httpClient,
tokenHTTPClient: tokenHTTPClient,
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
@@ -347,12 +334,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
// Convert sessionMaxAge from seconds to duration (0 will use default 24 hours)
sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second
sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger)
if err != nil {
cancelFunc()
return nil, fmt.Errorf("failed to create session manager: %w", err)
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults
if config.CookiePath != "" {
t.sessionManager.cookiePath = config.CookiePath
t.logger.Debugf("Using configured cookie path: %s", config.CookiePath)
}
t.sessionManager = sessionManager
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
// Initialize token resilience manager with default configuration
@@ -443,7 +429,6 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
// Add reference for this instance
rm.AddReference(name)
registerLiveInstance()
// Initialize metadata in a goroutine with proper tracking
if t.goroutineWG != nil {
@@ -521,58 +506,13 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
// Parameters:
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
// SSRF defense (audit ranks 3 & 4): a discovery document is attacker-
// influenced when the provider or its TLS is compromised. Reject any
// discovered endpoint pointed at a blocked address before the plugin issues
// outbound requests to it, so it can never be used to reach the cloud
// metadata service or an internal host.
allowLoopback := false
if pu, err := url.Parse(t.providerURL); err == nil {
allowLoopback = isLoopbackHost(pu.Hostname())
}
sanitize := func(name, raw string) string {
if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil {
t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err)
return ""
}
return raw
}
metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL)
metadata.AuthURL = sanitize("authorization", metadata.AuthURL)
metadata.TokenURL = sanitize("token", metadata.TokenURL)
metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL)
metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL)
metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL)
metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL)
// The introspection request authenticates with the client secret via HTTP
// Basic, so the endpoint must live on the same host as the operator-
// configured provider; otherwise a poisoned discovery document could
// exfiltrate the client secret to an attacker-controlled host.
if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) {
t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL)
metadata.IntrospectionURL = ""
}
// Pin the discovered issuer to the operator-configured provider host. The
// issuer is the trust anchor for JWT issuer validation, so a poisoned
// discovery document advertising an attacker-chosen issuer must never be
// stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the
// issuer on the same host as the configured providerURL. On mismatch, leave
// issuerURL empty/unchanged so downstream issuer validation fails closed
// rather than trusting the attacker-chosen value.
discoveredIssuer := metadata.Issuer
if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) {
t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer)
discoveredIssuer = ""
}
t.metadataMu.Lock()
t.jwksURL = metadata.JWKSURL
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = discoveredIssuer
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
@@ -585,7 +525,7 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
// this directly instead of acquiring metadataMu.RLock per request.
t.metadataSnapshot.Store(&MetadataSnapshot{
IssuerURL: discoveredIssuer,
IssuerURL: metadata.Issuer,
JWKSURL: metadata.JWKSURL,
TokenURL: metadata.TokenURL,
AuthURL: metadata.AuthURL,
-2
View File
@@ -194,7 +194,6 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
handler, err := New(ctx, nil, config, "test")
if err != nil {
@@ -323,7 +322,6 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
handler, err := New(ctx, nil, config, "test")
if err != nil {
+38 -56
View File
@@ -26,47 +26,38 @@ func TestInitializeMetadata(t *testing.T) {
name: "successful metadata initialization",
providerURL: "",
setupMock: func() *httptest.Server {
// Issuer must share the host with providerURL (the httptest
// server), otherwise the discovery doc is rejected as poisoned
// (audit ranks 21/22). Real providers keep issuer + endpoints on
// the same host, so derive them all from the server URL.
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: srv.URL,
AuthURL: srv.URL + "/auth",
TokenURL: srv.URL + "/token",
JWKSURL: srv.URL + "/jwks",
RevokeURL: srv.URL + "/revoke",
EndSessionURL: srv.URL + "/logout",
Issuer: "https://provider.example.com",
AuthURL: "https://provider.example.com/auth",
TokenURL: "https://provider.example.com/token",
JWKSURL: "https://provider.example.com/jwks",
RevokeURL: "https://provider.example.com/revoke",
EndSessionURL: "https://provider.example.com/logout",
})
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
return srv
},
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
if oidc.authURL != "https://provider.example.com/auth" {
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
}
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
if oidc.tokenURL != "https://provider.example.com/token" {
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
}
if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") {
if oidc.jwksURL != "https://provider.example.com/jwks" {
t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL)
}
if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") {
if oidc.revocationURL != "https://provider.example.com/revoke" {
t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL)
}
if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") {
if oidc.endSessionURL != "https://provider.example.com/logout" {
t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL)
}
if oidc.issuerURL == "" {
t.Errorf("expected issuerURL to be pinned to provider host, got empty")
}
},
wantPanic: false,
},
@@ -125,27 +116,24 @@ func TestInitializeMetadata(t *testing.T) {
name: "partial metadata response",
providerURL: "",
setupMock: func() *httptest.Server {
// Issuer host must match providerURL (audit ranks 21/22).
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Only return some fields
json.NewEncoder(w).Encode(map[string]string{
"issuer": srv.URL,
"authorization_endpoint": srv.URL + "/auth",
"token_endpoint": srv.URL + "/token",
"issuer": "https://partial.example.com",
"authorization_endpoint": "https://partial.example.com/auth",
"token_endpoint": "https://partial.example.com/token",
// Missing jwks_uri, revocation_endpoint, end_session_endpoint
})
}
}))
return srv
},
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
if oidc.authURL != "https://partial.example.com/auth" {
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
}
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
if oidc.tokenURL != "https://partial.example.com/token" {
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
}
// JWKS URL and others may be empty
@@ -210,22 +198,20 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
requestCount := 0
var mu sync.Mutex
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
requestCount++
mu.Unlock()
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
RevokeURL: server.URL + "/revoke",
EndSessionURL: server.URL + "/logout",
Issuer: "https://concurrent.example.com",
AuthURL: "https://concurrent.example.com/auth",
TokenURL: "https://concurrent.example.com/token",
JWKSURL: "https://concurrent.example.com/jwks",
RevokeURL: "https://concurrent.example.com/revoke",
EndSessionURL: "https://concurrent.example.com/logout",
})
}
}))
@@ -264,7 +250,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
oidc.initializeMetadata(server.URL)
// Verify initialization
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
if oidc.tokenURL != "https://concurrent.example.com/token" {
t.Errorf("expected tokenURL to be set")
}
}()
@@ -356,19 +342,17 @@ func TestProviderDetection(t *testing.T) {
// TestInitializationWaiting tests waiting for initialization to complete
func TestInitializationWaiting(t *testing.T) {
t.Run("wait for initialization completion", func(t *testing.T) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Delay response to simulate slow initialization
time.Sleep(100 * time.Millisecond)
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
Issuer: "https://slow.example.com",
AuthURL: "https://slow.example.com/auth",
TokenURL: "https://slow.example.com/token",
JWKSURL: "https://slow.example.com/jwks",
})
}
}))
@@ -405,7 +389,7 @@ func TestInitializationWaiting(t *testing.T) {
select {
case <-oidc.initComplete:
// Success
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
if oidc.tokenURL != "https://slow.example.com/token" {
t.Error("expected tokenURL to be set after initialization")
}
case <-time.After(2 * time.Second):
@@ -414,19 +398,17 @@ func TestInitializationWaiting(t *testing.T) {
})
t.Run("multiple waiters for initialization", func(t *testing.T) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Delay to ensure multiple waiters
time.Sleep(50 * time.Millisecond)
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
Issuer: "https://multi.example.com",
AuthURL: "https://multi.example.com/auth",
TokenURL: "https://multi.example.com/token",
JWKSURL: "https://multi.example.com/jwks",
})
}
}))
@@ -471,7 +453,7 @@ func TestInitializationWaiting(t *testing.T) {
select {
case <-oidc.initComplete:
// All waiters should see the same initialized state
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
if oidc.tokenURL != "https://multi.example.com/token" {
t.Errorf("waiter %d: expected tokenURL to be set", id)
}
case <-time.After(2 * time.Second):
+44 -68
View File
@@ -1875,14 +1875,14 @@ func TestHandleLogout(t *testing.T) {
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "/",
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *SessionData) {},
expectedStatus: http.StatusFound,
expectedURL: "/",
expectedURL: "http://example.com/",
host: "test-host",
},
{
@@ -2349,22 +2349,19 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
t.Skip("Skipping test in short mode")
}
// Create mock provider metadata server. Issuer + endpoints must share the
// host with ProviderURL (the httptest server), otherwise the discovery doc
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: mockServer.URL,
AuthURL: mockServer.URL + "/auth",
TokenURL: mockServer.URL + "/token",
JWKSURL: mockServer.URL + "/jwks",
RevokeURL: mockServer.URL + "/revoke",
EndSessionURL: mockServer.URL + "/end-session",
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
@@ -2377,7 +2374,6 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
RateLimit: 100,
}
// Create multiple middleware instances
@@ -2418,20 +2414,18 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
t.Fatalf("Middleware instance %d failed to initialize", i)
}
// Verify each instance has its own unique configuration. Issuer is now
// pinned to the provider host (audit ranks 21/22), so it equals the
// mock server URL rather than a fixed literal.
if m.issuerURL != mockServer.URL {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
// Verify each instance has its own unique configuration
if m.issuerURL != "https://test-issuer.com" {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
}
if m.authURL != mockServer.URL+"/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
if m.authURL != "https://test-issuer.com/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
}
if m.tokenURL != mockServer.URL+"/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
if m.tokenURL != "https://test-issuer.com/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
}
if m.jwksURL != mockServer.URL+"/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
if m.jwksURL != "https://test-issuer.com/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
}
if m.redirURLPath != routes[i]+"/callback" {
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
@@ -2445,16 +2439,15 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
m.ServeHTTP(rr, req)
// Should redirect (302) to the auth flow since not authenticated. The
// absolute auth URL is not asserted here: with issuer pinning (audit
// ranks 21/22) the discovery host equals the httptest server host,
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
// to emit a loopback authorization URL in this test environment. The
// per-instance auth/token/jwks/issuer URLs were already verified above;
// here we only confirm each instance independently triggers a redirect.
// Should redirect to auth URL since not authenticated
if rr.Code != http.StatusFound {
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
}
location := rr.Header().Get("Location")
if !strings.Contains(location, "https://test-issuer.com/auth") {
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
}
}
}
@@ -2467,43 +2460,33 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
}
// Create two mock provider metadata servers simulating different Keycloak realms
// Issuer + endpoints must share the host with each realm's ProviderURL
// (the httptest server), otherwise the discovery doc is rejected as
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
// path so the per-realm isolation assertions below still hold, but base
// the host on the server URL — which is exactly what a same-host Keycloak
// deployment looks like.
var realm1Server *httptest.Server
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
base := realm1Server.URL + "/realms/realm1"
metadata := ProviderMetadata{
Issuer: base,
AuthURL: base + "/protocol/openid-connect/auth",
TokenURL: base + "/protocol/openid-connect/token",
JWKSURL: base + "/protocol/openid-connect/certs",
EndSessionURL: base + "/protocol/openid-connect/logout",
Issuer: "https://keycloak.example.com/realms/realm1",
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
defer realm1Server.Close()
var realm2Server *httptest.Server
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
base := realm2Server.URL + "/realms/realm2"
metadata := ProviderMetadata{
Issuer: base,
AuthURL: base + "/protocol/openid-connect/auth",
TokenURL: base + "/protocol/openid-connect/token",
JWKSURL: base + "/protocol/openid-connect/certs",
EndSessionURL: base + "/protocol/openid-connect/logout",
Issuer: "https://keycloak.example.com/realms/realm2",
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
@@ -2517,7 +2500,6 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
CallbackURL: "/realm1/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm1_",
RateLimit: 100,
}
// Config for realm2
@@ -2528,7 +2510,6 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
CallbackURL: "/realm2/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm2_",
RateLimit: 100,
}
// Create middleware instances for both realms
@@ -2627,11 +2608,8 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
providerAvailable := false
var mu sync.Mutex
// Create mock provider that initially fails, then becomes available.
// Issuer + endpoints must share the host with ProviderURL (audit ranks
// 21/22), so derive them from the server URL.
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create mock provider that initially fails, then becomes available
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
available := providerAvailable
mu.Unlock()
@@ -2643,11 +2621,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
if r.URL.Path == "/.well-known/openid-configuration" {
metadata := ProviderMetadata{
Issuer: mockServer.URL,
AuthURL: mockServer.URL + "/auth",
TokenURL: mockServer.URL + "/token",
JWKSURL: mockServer.URL + "/jwks",
EndSessionURL: mockServer.URL + "/logout",
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
EndSessionURL: "https://test-issuer.com/logout",
}
json.NewEncoder(w).Encode(metadata)
return
@@ -2662,7 +2640,6 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
RateLimit: 100,
}
// Create middleware while provider is unavailable
@@ -4575,7 +4552,6 @@ func TestNewWithScopeAppending(t *testing.T) {
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
Scopes: tc.configScopes,
RateLimit: 100,
}
// Create middleware instance
-1
View File
@@ -1652,7 +1652,6 @@ func TestGoroutineLeaks(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client"
config.ClientSecret = "test-secret"
config.CallbackURL = "/callback"
handler, err := New(context.Background(), nil, config, "test")
require.NoError(t, err)
+1 -11
View File
@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
@@ -142,19 +141,10 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
}
var metadata ProviderMetadata
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to decode metadata: %w", err)
}
// Pin the advertised issuer to the configured provider host. The issuer is
// the trust anchor for JWT issuer validation; rejecting a mismatch here
// ensures a poisoned discovery document advertising an attacker-chosen
// issuer is never cached or returned. Real providers (Google, Azure,
// Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL.
if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) {
return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL)
}
// Cache for 1 hour by default
if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil {
mc.logger.Errorf("Failed to cache metadata: %v", err)
+7 -80
View File
@@ -472,7 +472,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - req: The HTTP request to process.
// - session: The user's session data containing tokens and claims.
// - redirectURL: The callback URL for re-authentication if needed.
//
// processAuthorizedRequestRS is the requestState-aware variant of
// processAuthorizedRequest. It reads SessionData fields from the captured
// snapshot in rs instead of calling session.GetX() (each of which acquires
@@ -676,44 +675,6 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
//
// Session persistence is the CALLER's responsibility — it must happen before
// this function so Set-Cookie reaches the response.
// headerTemplateMaxLen bounds the length of a rendered operator-defined header
// template before it is forwarded downstream. Generous enough for an
// "Authorization: Bearer <jwt>" value but small enough to reject obviously
// abusive output. Matches the input-validation default header cap (8KB).
const headerTemplateMaxLen = 8192
// headerClaimMaxLen returns the maximum accepted length for a claim-derived
// header value (principal identifier, group, role). Reuses the operator-
// configured identifier cap (default 256) so a single setting governs both
// auth paths; falls back to 256 when unset.
func (t *TraefikOidc) headerClaimMaxLen() int {
if t.maxIdentifierLength > 0 {
return t.maxIdentifierLength
}
return 256
}
// sanitizeHeaderClaimList drops any group/role value that fails claim
// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an
// over-long value) and returns the surviving values. Failing closed on a bad
// entry prevents header injection and stops an embedded comma from injecting
// extra entries into the comma-joined header. headerName is used only for
// debug logging — the value is never logged.
func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string {
if len(values) == 0 {
return nil
}
safe := make([]string, 0, len(values))
for _, v := range values {
if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok {
safe = append(safe, clean)
} else {
t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName)
}
}
return safe
}
func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) {
var (
groups, roles []string
@@ -731,18 +692,11 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
return
}
if extractErr == nil {
// Sanitize each group/role before it is joined into a comma-
// delimited header. The cookie/session path does not otherwise
// sanitize claim-derived values (the bearer path sanitizes its
// identifier at construction), so a control char would enable
// header injection and an embedded comma would inject extra
// entries into the comma-joined header. Fail closed: drop any
// value that does not pass.
if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(safeGroups, ","))
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(safeRoles, ","))
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
}
@@ -763,26 +717,12 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
}
}
// Sanitize the principal identifier before injecting it into headers. The
// bearer path already sanitizes its identifier at construction; the
// cookie/session path does not, so a claim carrying control chars, bidi-
// override runes, or , ; = could inject or spoof header content. Fail
// closed: drop the identifier header(s) rather than forward a tainted value.
safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen())
if identifierOK {
req.Header.Set("X-Forwarded-User", safeIdentifier)
} else {
t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization")
}
req.Header.Set("X-Forwarded-User", p.Identifier)
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
if identifierOK {
req.Header.Set("X-Auth-Request-User", safeIdentifier)
} else {
t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization")
}
req.Header.Set("X-Auth-Request-User", p.Identifier)
if p.IDToken != "" {
req.Header.Set("X-Auth-Request-Token", p.IDToken)
}
@@ -807,21 +747,8 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques
continue
}
headerValue := buf.String()
// Sanitize the rendered output: template inputs are claim-derived
// and attacker-influenceable, so reject control chars (header
// injection), bidi-override runes, the , ; = delimiters, and an
// over-long value. Fail closed by dropping the header rather than
// forwarding a tainted value. Do not log the value (it commonly
// carries the access token); log only name + reason.
if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" {
t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason)
continue
}
req.Header.Set(headerName, headerValue)
// Do not log the value: templated headers commonly carry the access
// token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging
// it — even at debug — leaks credentials into logs.
t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue))
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
}
}
-404
View File
@@ -1,404 +0,0 @@
package traefikoidc
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// TestRank1_SessionCookieIsEncrypted verifies that the session cookie payload is
// AES-encrypted, not merely HMAC-signed. Regression test for the audit finding
// "session cookies signed but NOT encrypted": a single key left the stored OIDC
// tokens recoverable in plaintext from the raw cookie bytes.
func TestRank1_SessionCookieIsEncrypted(t *testing.T) {
const secret = "a-sufficiently-long-session-encryption-key"
authKey, encKey := deriveCookieKeys(secret)
if len(authKey) != 64 || len(encKey) != 32 {
t.Fatalf("expected 64-byte auth key and 32-byte enc key, got %d/%d", len(authKey), len(encKey))
}
if string(authKey) == string(encKey) {
t.Fatal("authentication and encryption keys must be independent")
}
const marker = "SUPER-SECRET-ACCESS-TOKEN-marker-value"
// Encode a session through the same two-key store the production code now
// builds (see NewSessionManager).
store := sessions.NewCookieStore(authKey, encKey)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
sess, err := store.New(req, "session")
if err != nil {
t.Fatalf("store.New failed: %v", err)
}
sess.Values["tok"] = marker
if err := sess.Save(req, rec); err != nil {
t.Fatalf("session save failed: %v", err)
}
var cookie *http.Cookie
for _, c := range rec.Result().Cookies() {
if c.Name == "session" {
cookie = c
}
}
if cookie == nil {
t.Fatal("no session cookie was set")
}
// The secret token must never appear in plaintext in the cookie value.
if strings.Contains(cookie.Value, marker) {
t.Error("marker token found in plaintext inside the session cookie value")
}
// A store holding only the authentication key (the previous behavior)
// must NOT be able to read the encrypted cookie — proving the payload is
// genuinely encrypted, not just signed.
signedOnly := sessions.NewCookieStore(authKey)
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
req2.AddCookie(cookie)
if _, derr := signedOnly.Get(req2, "session"); derr == nil {
t.Error("encrypted cookie should not be decodable without the encryption key")
}
// The full two-key store round-trips correctly.
req3 := httptest.NewRequest(http.MethodGet, "/", nil)
req3.AddCookie(cookie)
rt, derr := store.Get(req3, "session")
if derr != nil {
t.Fatalf("round-trip decode failed: %v", derr)
}
if got, _ := rt.Values["tok"].(string); got != marker {
t.Errorf("round-trip mismatch: got %q want %q", got, marker)
}
}
// TestRank2And6_InvalidConfigFailsClosed verifies that NewWithContext now calls
// Config.Validate() and fails closed on an empty or too-short session
// encryption key instead of silently substituting a public hardcoded key, and
// rejects other missing required fields. Regression test for "hardcoded default
// encryption key" + "Config.Validate() never called in production path".
func TestRank2And6_InvalidConfigFailsClosed(t *testing.T) {
base := func() *Config {
return &Config{
ProviderURL: "https://accounts.google.com",
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "this-is-a-valid-session-key-32b!",
RateLimit: 100,
}
}
// Sanity: a fully valid config still constructs.
p, err := NewWithContext(context.Background(), base(), nil, "valid")
if err != nil {
t.Fatalf("valid config should construct, got: %v", err)
}
if p != nil {
p.Close()
}
cases := []struct {
name string
mutate func(*Config)
}{
{"empty key", func(c *Config) { c.SessionEncryptionKey = "" }},
{"short key", func(c *Config) { c.SessionEncryptionKey = "tooshort" }},
{"missing providerURL", func(c *Config) { c.ProviderURL = "" }},
{"missing callbackURL", func(c *Config) { c.CallbackURL = "" }},
{"plaintext remote providerURL", func(c *Config) { c.ProviderURL = "http://accounts.google.com" }},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c := base()
tc.mutate(c)
plugin, err := NewWithContext(context.Background(), c, nil, tc.name)
if err == nil {
if plugin != nil {
plugin.Close()
}
t.Errorf("expected NewWithContext to reject config (%s), but it succeeded", tc.name)
}
})
}
}
// TestRank3_DiscoveredEndpointSSRFGuard verifies that endpoints from the
// provider discovery document are screened against SSRF targets before use.
func TestRank3_DiscoveredEndpointSSRFGuard(t *testing.T) {
tr := &TraefikOidc{}
blocked := []string{
"http://169.254.169.254/latest/meta-data/", // cloud metadata (link-local)
"http://[fe80::1]/jwks", // IPv6 link-local
"http://10.0.0.5/jwks", // private
"http://192.168.1.10/jwks", // private
"http://127.0.0.1/jwks", // loopback (allowLoopback=false)
"ftp://example.com/jwks", // disallowed scheme
}
for _, u := range blocked {
if err := tr.validateDiscoveredEndpoint(u, false); err == nil {
t.Errorf("expected discovered endpoint %q to be rejected", u)
}
}
allowed := []string{
"https://accounts.google.com/o/oauth2/v3/certs",
"https://www.googleapis.com/oauth2/v3/certs", // cross-domain JWKS must stay allowed
"", // empty optional endpoint
}
for _, u := range allowed {
if err := tr.validateDiscoveredEndpoint(u, false); err != nil {
t.Errorf("expected discovered endpoint %q to be allowed, got %v", u, err)
}
}
// Loopback is allowed only when the provider itself is loopback (dev/test).
if err := tr.validateDiscoveredEndpoint("http://127.0.0.1:8080/jwks", true); err != nil {
t.Errorf("loopback endpoint should be allowed when allowLoopback=true: %v", err)
}
// Private addresses are allowed when explicitly opted in.
trPriv := &TraefikOidc{allowPrivateIPAddresses: true}
if err := trPriv.validateDiscoveredEndpoint("http://10.0.0.5/jwks", false); err != nil {
t.Errorf("private endpoint should be allowed when allowPrivateIPAddresses=true: %v", err)
}
}
// TestRank4_IntrospectionHostPin verifies the host-equality check used to pin
// the credential-bearing introspection endpoint to the configured provider.
func TestRank4_IntrospectionHostPin(t *testing.T) {
if !sameHost("https://kc.example.com/realms/x", "https://kc.example.com/realms/x/protocol/openid-connect/token/introspect") {
t.Error("introspection on the same host as the provider should be accepted")
}
if sameHost("https://kc.example.com", "https://evil.example.net/introspect") {
t.Error("introspection on a different host must be rejected")
}
if sameHost("", "https://kc.example.com") || sameHost("https://kc.example.com", "") {
t.Error("empty URL must not be treated as a host match")
}
}
// TestRank5_OpenRedirectNeutralized verifies the helper the callback now applies
// to the stored incoming path forces a host-relative redirect target.
func TestRank5_OpenRedirectNeutralized(t *testing.T) {
cases := map[string]string{
"//evil.com/x": "/evil.com/x",
`/\evil.com`: "/evil.com",
"/legit/path": "/legit/path",
}
for in, want := range cases {
got := normalizeLogoutPath(in)
if got != want {
t.Errorf("normalizeLogoutPath(%q) = %q, want %q", in, got, want)
}
if strings.HasPrefix(got, "//") || strings.HasPrefix(got, `/\`) {
t.Errorf("normalizeLogoutPath(%q) = %q is still protocol-relative", in, got)
}
}
}
// TestRank14_ExcludedURLSegmentBoundary verifies excluded-URL matching is
// anchored at path-segment boundaries and cannot be widened into a bypass.
func TestRank14_ExcludedURLSegmentBoundary(t *testing.T) {
if !pathExcluded("/public", "/public") {
t.Error("exact match should be excluded")
}
if !pathExcluded("/public/page", "/public") {
t.Error("sub-path should be excluded")
}
if pathExcluded("/publicsecret", "/public") {
t.Error("/publicsecret must NOT be excluded by /public")
}
if pathExcluded("/public-admin", "/public") {
t.Error("/public-admin must NOT be excluded by /public")
}
if !pathExcluded("/health", "/health/") {
t.Error("trailing-slash config should still match the exact path")
}
if pathExcluded("/anything", "/") {
t.Error("root exclusion must not match arbitrary paths")
}
if !pathExcluded("/", "/") {
t.Error("root exclusion should match the root path")
}
}
// TestRank15_ForwardedHostSanitized verifies a crafted X-Forwarded-Host cannot
// inject CRLF, smuggle a second host, or otherwise poison the derived host.
func TestRank15_ForwardedHostSanitized(t *testing.T) {
mk := func(xfh string) *http.Request {
r := httptest.NewRequest(http.MethodGet, "http://real.example.com/x", nil)
r.Host = "real.example.com"
if xfh != "" {
r.Header.Set("X-Forwarded-Host", xfh)
}
return r
}
if got := utils.DetermineHost(mk("ext.example.com")); got != "ext.example.com" {
t.Errorf("clean X-Forwarded-Host should be honored, got %q", got)
}
if got := utils.DetermineHost(mk("a.example.com, evil.com")); got != "a.example.com" {
t.Errorf("multi-value X-Forwarded-Host should use first host only, got %q", got)
}
for _, bad := range []string{"evil.com\r\nSet-Cookie: x=1", "evil.com /x", " "} {
if got := utils.DetermineHost(mk(bad)); got != "real.example.com" {
t.Errorf("malformed X-Forwarded-Host %q should fall back to req.Host, got %q", bad, got)
}
}
}
// TestRank11_TransportPoolTLSIsolationAtLimit verifies that, once the client
// limit is reached, the transport pool reuses an existing transport only when
// its TLS settings match the caller's, and never hands back a transport built
// with different TLS trust settings.
func TestRank11_TransportPoolTLSIsolationAtLimit(t *testing.T) {
pool := &SharedTransportPool{
transports: make(map[string]*sharedTransport),
maxConns: 20,
maxClients: 5,
}
strict := DefaultHTTPClientConfig() // InsecureSkipVerify = false
t1 := pool.GetOrCreateTransport(strict)
if t1 == nil {
t.Fatal("expected a transport for the strict config")
}
// Saturate the client limit so subsequent calls hit the fallback path.
atomic.StoreInt32(&pool.clientCount, pool.maxClients)
// Same TLS settings, different (non-TLS) connection limit: safe to reuse.
sameTLS := DefaultHTTPClientConfig()
sameTLS.MaxConnsPerHost = 99
if got := pool.GetOrCreateTransport(sameTLS); got != t1 {
t.Error("at the limit a TLS-compatible config should reuse the existing transport")
}
// Different TLS settings (InsecureSkipVerify): must NOT reuse the strict
// transport — returning nil lets the caller fall back to a verifying default.
insecure := DefaultHTTPClientConfig()
insecure.InsecureSkipVerify = true
if got := pool.GetOrCreateTransport(insecure); got == t1 {
t.Error("at the limit a config with different TLS settings must not reuse the strict transport")
}
}
// TestRank9_RedisFingerprint verifies divergent explicit Redis backends produce
// distinct fingerprints (used to warn about ignored cache config), while an
// absent or disabled Redis yields the empty (no-warning) fingerprint.
func TestRank9_RedisFingerprint(t *testing.T) {
if redisFingerprint(nil) != "" {
t.Error("nil config should yield an empty fingerprint")
}
if redisFingerprint(&Config{}) != "" {
t.Error("config without Redis should yield an empty fingerprint")
}
if redisFingerprint(&Config{Redis: &RedisConfig{Enabled: false, Address: "a:6379"}}) != "" {
t.Error("disabled Redis should yield an empty fingerprint")
}
a := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "a:6379", KeyPrefix: "p"}})
b := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "b:6379", KeyPrefix: "p"}})
if a == "" || a == b {
t.Errorf("distinct enabled backends must produce distinct non-empty fingerprints (%q vs %q)", a, b)
}
}
// TestRank10_TokenTypeCacheKeyNoCollision verifies that two different tokens
// sharing the same 32-character JWT header prefix are classified independently.
// The previous 32-char cache key would have collided and mis-classified them.
func TestRank10_TokenTypeCacheKeyNoCollision(t *testing.T) {
tr := &TraefikOidc{
tokenTypeCache: NewCache(),
suppressDiagnosticLogs: true,
clientID: "client",
}
// A header prefix longer than 32 chars, shared by both tokens.
prefix := "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ"
idJWT := &JWT{Header: map[string]interface{}{}, Claims: map[string]interface{}{"nonce": "n"}}
accessJWT := &JWT{Header: map[string]interface{}{"typ": "at+jwt"}, Claims: map[string]interface{}{}}
if !tr.detectTokenType(idJWT, prefix+".id.sig") {
t.Error("token with a nonce claim should be detected as an ID token")
}
if tr.detectTokenType(accessJWT, prefix+".access.sig") {
t.Error("access token (typ=at+jwt) must not be mis-classified as ID despite the shared 32-char prefix")
}
}
// TestRank12_LiveInstanceCounter verifies the process-global instance counter
// that gates teardown of shared singleton tasks.
func TestRank12_LiveInstanceCounter(t *testing.T) {
start := atomic.LoadInt32(&liveInstanceCount)
registerLiveInstance()
registerLiveInstance()
if got := atomic.LoadInt32(&liveInstanceCount); got != start+2 {
t.Fatalf("expected %d live instances, got %d", start+2, got)
}
if rem := unregisterLiveInstance(); rem != start+1 {
t.Errorf("expected %d remaining, got %d", start+1, rem)
}
if rem := unregisterLiveInstance(); rem != start {
t.Errorf("expected %d remaining, got %d", start, rem)
}
}
// TestRank13_CookieMaxAgeMatchesSessionLifetime verifies the cookie store's
// MaxAge (which bounds both the cookie Max-Age and the codec's cryptographic
// timestamp validity) is bound to the configured session lifetime rather than
// gorilla's 30-day default.
func TestRank13_CookieMaxAgeMatchesSessionLifetime(t *testing.T) {
maxAge := 2 * time.Hour
sm, err := NewSessionManager(strings.Repeat("k", 40), false, "", "", maxAge, NewLogger("error"))
if err != nil {
t.Fatalf("NewSessionManager failed: %v", err)
}
defer sm.cancel()
cs, ok := sm.store.(*sessions.CookieStore)
if !ok {
t.Fatal("session store is not a *sessions.CookieStore")
}
if got := cs.Options.MaxAge; got != int(maxAge.Seconds()) {
t.Errorf("cookie store MaxAge = %d, want %d (bound to sessionMaxAge)", got, int(maxAge.Seconds()))
}
}
// TestRank33And34_HeaderSanitizationDistinction verifies the two header sinks
// use the right strictness: free-form templated header VALUES (rank 34) permit
// , ; = (e.g. an opaque "Bearer <token>" or an LDAP-DN claim) but reject CR/LF,
// bidi, and over-length; claim values joined into delimited/identifier headers
// (rank 33) additionally reject , ; =.
func TestRank33And34_HeaderSanitizationDistinction(t *testing.T) {
// Rank 34 — free-form header value.
if headerValueReason("Bearer abc=def==", 8192) != "" {
t.Error("'=' must be allowed in a free-form header value (opaque bearer token)")
}
if headerValueReason("cn=user,ou=eng;dc=x", 8192) != "" {
t.Error("',;=' must be allowed in a free-form header value (e.g. an LDAP DN claim)")
}
if headerValueReason("evil"+string(rune(13))+string(rune(10))+"Injected: 1", 8192) == "" {
t.Error("CR/LF must be rejected in a header value (injection)")
}
if headerValueReason("toolong", 3) == "" {
t.Error("over-length value must be rejected")
}
// Rank 33 — claim value bound for a delimited/identifier header.
if _, ok := sanitizeHeaderClaimValue("admins,superadmins", 256); ok {
t.Error("a comma must be rejected in a value joined into a comma-delimited header")
}
if _, ok := sanitizeHeaderClaimValue("normal-user@example.com", 256); !ok {
t.Error("a clean identifier must pass claim sanitization")
}
if _, ok := sanitizeHeaderClaimValue("evil"+string(rune(13))+string(rune(10))+"X: 1", 256); ok {
t.Error("CR/LF must be rejected in a claim value")
}
}
+590
View File
@@ -0,0 +1,590 @@
package traefikoidc
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
)
// SecurityEventType categorizes different types of security events
// that can occur during OIDC authentication and authorization flows.
type SecurityEventType string
// Security event types for monitoring and alerting
const (
// AuthFailure indicates a failed authentication attempt
AuthFailure SecurityEventType = "authentication_failure"
// TokenValidFailure indicates JWT token validation failed
TokenValidFailure SecurityEventType = "token_validation_failure"
// RateLimitHit indicates rate limiting was triggered
RateLimitHit SecurityEventType = "rate_limit_hit"
// SuspiciousActivity indicates potentially malicious behavior
SuspiciousActivity SecurityEventType = "suspicious_activity"
)
// DefaultSeverity returns the default severity level for each security event type.
// Severity levels are: low, medium, high.
func (t SecurityEventType) DefaultSeverity() string {
switch t {
case AuthFailure:
return "medium"
case TokenValidFailure:
return "medium"
case RateLimitHit:
return "low"
case SuspiciousActivity:
return "high"
default:
return "medium"
}
}
// IPFailureType returns a string identifier for categorizing failures
// by IP address for rate limiting and blocking decisions.
func (t SecurityEventType) IPFailureType() string {
switch t {
case AuthFailure:
return "auth_failure"
case TokenValidFailure:
return "token_failure"
case SuspiciousActivity:
return "suspicious"
default:
return "general"
}
}
// SecurityEvent represents a security-related event with comprehensive context.
// Contains timing information, IP address, user agent, request details,
// and custom event-specific data for security analysis and alerting.
type SecurityEvent struct {
// Timestamp when the event occurred
Timestamp time.Time `json:"timestamp"`
// Details contains event-specific additional information
Details map[string]interface{} `json:"details,omitempty"`
// Type categorizes the event (auth_failure, token_failure, etc.)
Type string `json:"type"`
// Severity indicates event importance (low, medium, high)
Severity string `json:"severity"`
// ClientIP is the source IP address of the request
ClientIP string `json:"client_ip"`
// UserAgent is the User-Agent header from the request
UserAgent string `json:"user_agent"`
// RequestPath is the requested URL path
RequestPath string `json:"request_path"`
// Message provides human-readable description of the event
Message string `json:"message"`
}
// SecurityMonitor provides comprehensive security monitoring for the OIDC middleware.
// It tracks failures by IP address, detects suspicious patterns, enforces
// rate limits, and can trigger custom security event handlers.
type SecurityMonitor struct {
ipFailures map[string]*IPFailureTracker
patternDetector *SuspiciousPatternDetector
logger *Logger
cleanupTask *BackgroundTask
eventHandlers []SecurityEventHandler
config SecurityMonitorConfig
ipMutex sync.RWMutex
}
// IPFailureTracker maintains failure statistics and blocking state for an IP address.
// Used for implementing progressive penalties and automatic IP blocking based on
// failure patterns, with support for different failure types for
// rate limiting and IP blocking decisions.
type IPFailureTracker struct {
// LastFailure timestamp of the most recent failure
LastFailure time.Time
// FirstFailure timestamp of the first failure in current window
FirstFailure time.Time
// BlockedUntil indicates when the IP block expires
BlockedUntil time.Time
// FailureTypes tracks counts by failure type
FailureTypes map[string]int64
// FailureCount total number of failures
FailureCount int64
// mutex protects concurrent access to tracker data
mutex sync.RWMutex
// IsBlocked indicates if this IP is currently blocked
IsBlocked bool
}
// SuspiciousPatternDetector identifies attack patterns that may indicate coordinated threats.
// Analyzes events across multiple time windows to detect rapid failures, distributed attacks,
// and persistent attack patterns that individual IP monitoring might miss.
type SuspiciousPatternDetector struct {
// recentEvents stores recent security events for analysis
recentEvents []SecurityEvent
// shortWindow defines time frame for rapid failure detection
shortWindow time.Duration
// mediumWindow defines time frame for distributed attack detection
mediumWindow time.Duration
// longWindow defines time frame for persistent attack detection
longWindow time.Duration
// rapidFailureThreshold triggers rapid failure alerts
rapidFailureThreshold int
// distributedAttackThreshold triggers distributed attack alerts
distributedAttackThreshold int
// persistentAttackThreshold triggers persistent attack alerts
persistentAttackThreshold int
// eventsMutex protects concurrent access to events
eventsMutex sync.RWMutex
}
// SecurityEventHandler defines the interface for processing security events.
// Implementations can log events, send alerts, update external systems,
// or trigger automated response actions.
type SecurityEventHandler interface {
// HandleSecurityEvent processes a security event
HandleSecurityEvent(event SecurityEvent)
}
// SecurityMonitorConfig contains configuration parameters for the security monitor.
// Controls thresholds, time windows, and behavior for security monitoring.
type SecurityMonitorConfig struct {
// MaxFailuresPerIP sets the failure threshold before blocking
MaxFailuresPerIP int `json:"max_failures_per_ip"`
// FailureWindowMinutes defines the time window for counting failures
FailureWindowMinutes int `json:"failure_window_minutes"`
// BlockDurationMinutes sets how long to block an IP
BlockDurationMinutes int `json:"block_duration_minutes"`
// RapidFailureThreshold triggers rapid failure detection
RapidFailureThreshold int `json:"rapid_failure_threshold"`
// CleanupIntervalMinutes sets cleanup frequency for old data
CleanupIntervalMinutes int `json:"cleanup_interval_minutes"`
RetentionHours int `json:"retention_hours"`
EnablePatternDetection bool `json:"enable_pattern_detection"`
EnableDetailedLogging bool `json:"enable_detailed_logging"`
LogSuspiciousOnly bool `json:"log_suspicious_only"`
}
// DefaultSecurityMonitorConfig returns a default configuration
func DefaultSecurityMonitorConfig() SecurityMonitorConfig {
return SecurityMonitorConfig{
MaxFailuresPerIP: 10,
FailureWindowMinutes: 15,
BlockDurationMinutes: 60,
EnablePatternDetection: true,
RapidFailureThreshold: 5,
EnableDetailedLogging: true,
LogSuspiciousOnly: false,
CleanupIntervalMinutes: 30,
RetentionHours: 24,
}
}
// NewSecurityMonitor creates a new security monitor instance
func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor {
sm := &SecurityMonitor{
ipFailures: make(map[string]*IPFailureTracker),
eventHandlers: make([]SecurityEventHandler, 0),
config: config,
logger: logger,
patternDetector: NewSuspiciousPatternDetector(),
}
sm.startCleanupRoutine()
return sm
}
// NewSuspiciousPatternDetector creates a new pattern detector
func NewSuspiciousPatternDetector() *SuspiciousPatternDetector {
return &SuspiciousPatternDetector{
shortWindow: 1 * time.Minute,
mediumWindow: 5 * time.Minute,
longWindow: 15 * time.Minute,
rapidFailureThreshold: 5,
distributedAttackThreshold: 20,
persistentAttackThreshold: 50,
recentEvents: make([]SecurityEvent, 0),
}
}
// RecordSecurityEvent is a generic method to record any type of security event
func (sm *SecurityMonitor) RecordSecurityEvent(
eventType SecurityEventType,
clientIP, userAgent, requestPath string,
message string,
details map[string]interface{},
trackIPFailure bool) {
event := SecurityEvent{
Type: string(eventType),
Severity: eventType.DefaultSeverity(),
Timestamp: time.Now(),
ClientIP: clientIP,
UserAgent: userAgent,
RequestPath: requestPath,
Message: message,
Details: details,
}
if trackIPFailure {
sm.recordIPFailure(clientIP, eventType.IPFailureType())
}
sm.processSecurityEvent(event)
}
// RecordAuthenticationFailure records an authentication failure event
func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) {
if details == nil {
details = make(map[string]interface{})
}
details["reason"] = reason
sm.RecordSecurityEvent(
AuthFailure,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Authentication failed: %s", reason),
details,
true,
)
}
// RecordTokenValidationFailure records a token validation failure
func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) {
details := map[string]interface{}{
"reason": reason,
}
if tokenPrefix != "" {
details["token_prefix"] = tokenPrefix
}
sm.RecordSecurityEvent(
TokenValidFailure,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Token validation failed: %s", reason),
details,
true,
)
}
// RecordRateLimitHit records when rate limiting is triggered
func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) {
details := map[string]interface{}{
"limit_type": "token_verification",
}
sm.RecordSecurityEvent(
RateLimitHit,
clientIP,
userAgent,
requestPath,
"Rate limit exceeded",
details,
true,
)
}
// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories
func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) {
if details == nil {
details = make(map[string]interface{})
}
details["activity_type"] = activityType
sm.RecordSecurityEvent(
SuspiciousActivity,
clientIP,
userAgent,
requestPath,
fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description),
details,
true,
)
}
// recordIPFailure tracks failures for a specific IP address
func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) {
sm.ipMutex.Lock()
defer sm.ipMutex.Unlock()
tracker, exists := sm.ipFailures[clientIP]
if !exists {
tracker = &IPFailureTracker{
FailureTypes: make(map[string]int64),
FirstFailure: time.Now(),
}
sm.ipFailures[clientIP] = tracker
}
tracker.mutex.Lock()
defer tracker.mutex.Unlock()
tracker.FailureCount++
tracker.LastFailure = time.Now()
tracker.FailureTypes[failureType]++
windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute)
if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) {
if !tracker.IsBlocked {
tracker.IsBlocked = true
tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute)
sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes)
blockEvent := SecurityEvent{
Type: "ip_blocked",
Severity: "high",
Timestamp: time.Now(),
ClientIP: clientIP,
Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes),
Details: map[string]interface{}{
"failure_count": tracker.FailureCount,
"failure_types": tracker.FailureTypes,
"blocked_until": tracker.BlockedUntil,
},
}
sm.processSecurityEvent(blockEvent)
}
}
}
// IsIPBlocked checks if an IP address is currently blocked
func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool {
sm.ipMutex.RLock()
defer sm.ipMutex.RUnlock()
tracker, exists := sm.ipFailures[clientIP]
if !exists {
return false
}
tracker.mutex.RLock()
defer tracker.mutex.RUnlock()
if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) {
return true
}
if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) {
tracker.IsBlocked = false
sm.logger.Infof("IP %s automatically unblocked", clientIP)
}
return false
}
// processSecurityEvent processes a security event through all handlers and pattern detection
func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) {
if sm.config.EnablePatternDetection {
sm.patternDetector.AddEvent(event)
if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 {
if len(patterns) == 1 {
sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0])
} else {
sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns)
}
for _, pattern := range patterns {
patternEvent := SecurityEvent{
Type: "suspicious_pattern",
Severity: "high",
Timestamp: time.Now(),
Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern),
Details: map[string]interface{}{
"pattern_type": pattern,
"trigger_event": event,
},
}
sm.handleSecurityEvent(patternEvent)
}
}
}
sm.handleSecurityEvent(event)
}
// handleSecurityEvent sends the event to all registered handlers
func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) {
if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") {
sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)",
event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath)
}
for _, handler := range sm.eventHandlers {
go handler.HandleSecurityEvent(event)
}
}
// AddEventHandler adds a security event handler
func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) {
sm.eventHandlers = append(sm.eventHandlers, handler)
}
// This is kept for API compatibility but doesn't collect actual metrics
func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} {
return map[string]interface{}{
"tracked_ips": 0,
}
}
// AddEvent adds an event to the pattern detector
func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) {
spd.eventsMutex.Lock()
defer spd.eventsMutex.Unlock()
spd.recentEvents = append(spd.recentEvents, event)
cutoff := time.Now().Add(-spd.longWindow)
var filteredEvents []SecurityEvent
for _, e := range spd.recentEvents {
if e.Timestamp.After(cutoff) {
filteredEvents = append(filteredEvents, e)
}
}
spd.recentEvents = filteredEvents
}
// DetectSuspiciousPatterns analyzes recent events for suspicious patterns
func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string {
spd.eventsMutex.RLock()
defer spd.eventsMutex.RUnlock()
var patterns []string
now := time.Now()
ipCounts := make(map[string]int)
shortWindowStart := now.Add(-spd.shortWindow)
for _, event := range spd.recentEvents {
if event.Timestamp.After(shortWindowStart) &&
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
ipCounts[event.ClientIP]++
}
}
for ip, count := range ipCounts {
if count >= spd.rapidFailureThreshold {
patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip))
}
}
mediumWindowStart := now.Add(-spd.mediumWindow)
uniqueFailingIPs := make(map[string]bool)
for _, event := range spd.recentEvents {
if event.Timestamp.After(mediumWindowStart) &&
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
uniqueFailingIPs[event.ClientIP] = true
}
}
if len(uniqueFailingIPs) >= spd.distributedAttackThreshold {
patterns = append(patterns, "distributed_attack_pattern")
}
longWindowStart := now.Add(-spd.longWindow)
persistentFailures := 0
for _, event := range spd.recentEvents {
if event.Timestamp.After(longWindowStart) &&
(event.Type == "authentication_failure" || event.Type == "token_validation_failure") {
persistentFailures++
}
}
if persistentFailures >= spd.persistentAttackThreshold {
patterns = append(patterns, "persistent_attack_pattern")
}
return patterns
}
// startCleanupRoutine starts the background cleanup routine
func (sm *SecurityMonitor) startCleanupRoutine() {
sm.cleanupTask = NewBackgroundTask(
"security-monitor-cleanup",
time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute,
sm.cleanup,
sm.logger)
sm.cleanupTask.Start()
}
// StopCleanupRoutine stops the background cleanup routine
func (sm *SecurityMonitor) StopCleanupRoutine() {
if sm.cleanupTask != nil {
sm.cleanupTask.Stop()
sm.cleanupTask = nil
}
}
// cleanup removes old tracking data
func (sm *SecurityMonitor) cleanup() {
sm.ipMutex.Lock()
defer sm.ipMutex.Unlock()
cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour)
for ip, tracker := range sm.ipFailures {
tracker.mutex.RLock()
shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked
tracker.mutex.RUnlock()
if shouldRemove {
delete(sm.ipFailures, ip)
}
}
sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures))
}
// ExtractClientIP extracts the client IP from the request, considering proxy headers
func ExtractClientIP(r *http.Request) string {
if xri := r.Header.Get("X-Real-IP"); xri != "" {
if net.ParseIP(xri) != nil {
return xri
}
}
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
if len(ips) > 0 {
ip := strings.TrimSpace(ips[0])
if net.ParseIP(ip) != nil {
return ip
}
}
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
// LoggingSecurityEventHandler logs security events to the standard logger
type LoggingSecurityEventHandler struct {
logger *Logger
}
// NewLoggingSecurityEventHandler creates a new logging event handler
func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler {
return &LoggingSecurityEventHandler{logger: logger}
}
// HandleSecurityEvent implements SecurityEventHandler
func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
switch event.Severity {
case "high":
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
case "medium":
h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
case "low":
h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
default:
h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP)
}
}
+285
View File
@@ -0,0 +1,285 @@
package traefikoidc
import (
"net/http/httptest"
"strconv"
"testing"
"time"
)
func TestSecurityMonitor(t *testing.T) {
config := DefaultSecurityMonitorConfig()
config.MaxFailuresPerIP = 3
config.BlockDurationMinutes = 1 // 1 minute for testing
config.CleanupIntervalMinutes = 1
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
defer func() {
// Allow cleanup goroutine to finish
time.Sleep(150 * time.Millisecond)
}()
t.Run("Record authentication failure", func(t *testing.T) {
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "invalid credentials", nil)
// Should not be blocked after first failure
if monitor.IsIPBlocked("192.168.1.1") {
t.Error("IP should not be blocked after first failure")
}
})
t.Run("IP blocked after max failures", func(t *testing.T) {
// Record multiple failures
for i := 0; i < config.MaxFailuresPerIP; i++ {
monitor.RecordAuthenticationFailure("192.168.1.2", "test-agent", "/login", "invalid credentials", nil)
}
// Should be blocked now
if !monitor.IsIPBlocked("192.168.1.2") {
t.Error("IP should be blocked after max failures")
}
})
t.Run("Token validation failure", func(t *testing.T) {
// Just verify the method doesn't panic
monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123")
})
t.Run("Rate limit hit", func(t *testing.T) {
// Just verify the method doesn't panic
monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api")
})
t.Run("Suspicious activity", func(t *testing.T) {
details := map[string]interface{}{"pattern": "unusual"}
// Just verify the method doesn't panic
monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details)
})
}
func TestSuspiciousPatternDetector(t *testing.T) {
detector := NewSuspiciousPatternDetector()
t.Run("Add events and detect patterns", func(t *testing.T) {
// Add multiple events from same IP
for i := 0; i < 10; i++ {
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1.100",
Timestamp: time.Now(),
}
detector.AddEvent(event)
}
patterns := detector.DetectSuspiciousPatterns()
found := false
for _, p := range patterns {
if p == "rapid_failures_from_ip_192.168.1.100" {
found = true
break
}
}
if !found {
t.Error("Expected to detect rapid failure pattern")
}
})
t.Run("Detect distributed attack pattern", func(t *testing.T) {
// Add failures from many different IPs
for i := 0; i < 25; i++ {
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1." + strconv.Itoa(100+i),
Timestamp: time.Now(),
}
detector.AddEvent(event)
}
patterns := detector.DetectSuspiciousPatterns()
found := false
for _, p := range patterns {
if p == "distributed_attack_pattern" {
found = true
break
}
}
if !found {
t.Error("Expected to detect distributed attack pattern")
}
})
}
func TestExtractClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
headers map[string]string
expectedIP string
}{
{
name: "Direct connection",
remoteAddr: "192.168.1.1:12345",
expectedIP: "192.168.1.1",
},
{
name: "X-Forwarded-For header",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 10.0.0.1"},
expectedIP: "203.0.113.1",
},
{
name: "X-Real-IP header",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{"X-Real-IP": "203.0.113.2"},
expectedIP: "203.0.113.2",
},
{
name: "Multiple headers - X-Real-IP takes precedence",
remoteAddr: "10.0.0.1:12345",
headers: map[string]string{
"X-Forwarded-For": "203.0.113.1",
"X-Real-IP": "203.0.113.2",
},
expectedIP: "203.0.113.2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.RemoteAddr = tt.remoteAddr
for key, value := range tt.headers {
req.Header.Set(key, value)
}
ip := ExtractClientIP(req)
if ip != tt.expectedIP {
t.Errorf("Expected IP %s, got %s", tt.expectedIP, ip)
}
})
}
}
func TestSecurityEventHandlers(t *testing.T) {
t.Run("Logging security event handler", func(t *testing.T) {
logger := NewLogger("debug")
handler := NewLoggingSecurityEventHandler(logger)
event := SecurityEvent{
Type: "authentication_failure",
ClientIP: "192.168.1.1",
Timestamp: time.Now(),
Message: "Test failure",
Severity: "medium",
}
// Should not panic
handler.HandleSecurityEvent(event)
})
// Metrics security event handler test removed as part of metrics cleanup
}
func TestSecurityMonitorEventHandlers(t *testing.T) {
config := DefaultSecurityMonitorConfig()
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Add event handler with proper synchronization
handlerCalled := make(chan bool, 1)
handler := &testSecurityEventHandler{
callback: func(event SecurityEvent) {
select {
case handlerCalled <- true:
default:
// Channel already has a value, don't block
}
},
}
monitor.AddEventHandler(handler)
monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "test failure", nil)
// Wait for event handler to be called with timeout
select {
case <-handlerCalled:
// Success - handler was called
case <-time.After(100 * time.Millisecond):
t.Error("Expected event handler to be called within timeout")
}
}
// Test helper for security event handler
type testSecurityEventHandler struct {
callback func(SecurityEvent)
}
func (h *testSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) {
h.callback(event)
}
func TestDefaultSecurityMonitorConfig(t *testing.T) {
config := DefaultSecurityMonitorConfig()
if config.MaxFailuresPerIP <= 0 {
t.Error("Expected positive MaxFailuresPerIP")
}
if config.BlockDurationMinutes <= 0 {
t.Error("Expected positive BlockDurationMinutes")
}
if config.CleanupIntervalMinutes <= 0 {
t.Error("Expected positive CleanupIntervalMinutes")
}
if config.FailureWindowMinutes <= 0 {
t.Error("Expected positive FailureWindowMinutes")
}
}
func TestSecurityMonitorCleanup(t *testing.T) {
config := DefaultSecurityMonitorConfig()
config.CleanupIntervalMinutes = 1
config.BlockDurationMinutes = 1
config.RetentionHours = 1
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Block an IP
for i := 0; i < config.MaxFailuresPerIP; i++ {
monitor.RecordAuthenticationFailure("192.168.1.99", "test-agent", "/login", "test", nil)
}
// Verify it's blocked
if !monitor.IsIPBlocked("192.168.1.99") {
t.Error("IP should be blocked")
}
// Wait a bit and check if it gets unblocked automatically
time.Sleep(100 * time.Millisecond)
// The IP should still be blocked since we haven't waited long enough
if !monitor.IsIPBlocked("192.168.1.99") {
t.Error("IP should still be blocked")
}
}
func TestSecurityEventTypes(t *testing.T) {
config := DefaultSecurityMonitorConfig()
logger := NewLogger("debug")
monitor := NewSecurityMonitor(config, logger)
// Test different event types - just verify they don't panic
monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil)
monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123")
monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api")
details := map[string]interface{}{"pattern": "test"}
monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details)
// Just verify GetSecurityMetrics doesn't panic
_ = monitor.GetSecurityMetrics()
}
+15 -63
View File
@@ -4,9 +4,7 @@ import (
"bytes"
"compress/gzip"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
@@ -33,45 +31,6 @@ func constantTimeStringCompare(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
// deriveCookieKeys derives an independent 64-byte HMAC authentication key and a
// 32-byte AES-256 encryption key from the operator-provided session encryption
// key using HKDF-SHA256 (RFC 5869).
//
// gorilla/securecookie only ENCRYPTS the cookie payload when a block
// (encryption) key is supplied; constructing the store with a single key leaves
// sessions signed-but-plaintext, so the OIDC access/refresh/ID tokens stored in
// the cookie are recoverable by anyone who can read the raw cookie bytes. Two
// independent keys are derived here so the cookie is both encrypted and
// authenticated. HKDF is implemented with stdlib hmac+sha256 so it runs under
// Traefik's yaegi interpreter, which may not export crypto/hkdf.
func deriveCookieKeys(secret string) (authKey, encKey []byte) {
okm := hkdfSHA256([]byte(secret), nil, []byte("traefikoidc session cookie keys v1"), 96)
return okm[:64], okm[64:96]
}
// hkdfSHA256 performs HKDF-Extract followed by HKDF-Expand (RFC 5869) using
// HMAC-SHA256 and returns length bytes of output keying material.
func hkdfSHA256(ikm, salt, info []byte, length int) []byte {
if len(salt) == 0 {
salt = make([]byte, sha256.Size)
}
// Extract: PRK = HMAC-SHA256(salt, IKM)
ext := hmac.New(sha256.New, salt)
ext.Write(ikm)
prk := ext.Sum(nil)
// Expand: T(i) = HMAC-SHA256(PRK, T(i-1) | info | i)
var out, t []byte
for i := byte(1); len(out) < length; i++ {
exp := hmac.New(sha256.New, prk)
exp.Write(t)
exp.Write(info)
exp.Write([]byte{i})
t = exp.Sum(nil)
out = append(out, t...)
}
return out[:length]
}
// min returns the minimum of two integers.
// This is a utility function used throughout the session management code.
// Parameters:
@@ -159,12 +118,12 @@ var knownSessionKeys = map[string]bool{
"id_token": true,
"user_identifier": true,
"authenticated": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
"incoming_path": true,
"created_at": true,
"redirect_count": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
"incoming_path": true,
"created_at": true,
"redirect_count": true,
}
// compressCombinedPayload compresses the combined session payload using gzip.
@@ -423,6 +382,7 @@ type SessionManager struct {
cancel context.CancelFunc
cookieDomain string
cookiePrefix string
cookiePath string
sessionMaxAge time.Duration
activeSessions int64
poolHits int64
@@ -464,13 +424,8 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
ctx, cancel := context.WithCancel(context.Background())
// Derive independent authentication + encryption keys so the session cookie
// is AES-256 encrypted and HMAC authenticated, not merely signed. See
// deriveCookieKeys: a single key would leave the stored tokens in plaintext.
authKey, encKey := deriveCookieKeys(encryptionKey)
sm := &SessionManager{
store: sessions.NewCookieStore(authKey, encKey),
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
cookieDomain: cookieDomain,
cookiePrefix: cookiePrefix,
@@ -481,14 +436,6 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
cancel: cancel,
}
// Bind the cookie codec's timestamp validity (and the cookie Max-Age) to the
// configured session lifetime instead of gorilla's 30-day default, so a
// stolen cookie is not cryptographically valid for up to 30 days regardless
// of the (possibly much shorter) configured sessionMaxAge (rank 13).
if cs, ok := sm.store.(*sessions.CookieStore); ok {
cs.MaxAge(int(sessionMaxAge.Seconds()))
}
// Initialize global memory monitoring (singleton)
sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger)
@@ -905,7 +852,12 @@ func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *h
}
options.HttpOnly = true
options.Path = "/" // Ensure cookies are available on all paths for OAuth flow
// Use configured cookie path (default "/" for backward compatibility)
cookiePath := sm.cookiePath
if cookiePath == "" {
cookiePath = "/"
}
options.Path = cookiePath
if sm.cookieDomain != "" {
options.Domain = sm.cookieDomain
@@ -1620,7 +1572,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
sd.sessionMutex.Lock()
sd.clearAllSessionData(r, true)
// Release the lock before calling Save to prevent deadlock
sd.sessionMutex.Unlock()
+25 -44
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
@@ -65,23 +64,30 @@ type Config struct {
// IdPs do not expose RT TTL on the wire, so this is intentionally a
// conservative heuristic; tune to match your provider configuration.
// Default 21600 (6h). Set to 0 to disable the check.
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
SessionMaxAge int `json:"sessionMaxAge"`
RateLimit int `json:"rateLimit"`
OverrideScopes bool `json:"overrideScopes"`
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
RequireTokenIntrospection bool `json:"requireTokenIntrospection,omitempty"`
AllowOpaqueTokens bool `json:"allowOpaqueTokens,omitempty"`
StrictAudienceValidation bool `json:"strictAudienceValidation,omitempty"`
EnablePKCE bool `json:"enablePKCE"`
ForceHTTPS bool `json:"forceHTTPS"`
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
StripAuthCookies bool `json:"stripAuthCookies,omitempty"`
EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
SessionMaxAge int `json:"sessionMaxAge"`
RateLimit int `json:"rateLimit"`
OverrideScopes bool `json:"overrideScopes"`
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
RequireTokenIntrospection bool `json:"requireTokenIntrospection,omitempty"`
AllowOpaqueTokens bool `json:"allowOpaqueTokens,omitempty"`
StrictAudienceValidation bool `json:"strictAudienceValidation,omitempty"`
EnablePKCE bool `json:"enablePKCE"`
ForceHTTPS bool `json:"forceHTTPS"`
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
StripAuthCookies bool `json:"stripAuthCookies,omitempty"`
// CookiePath restricts session cookies to a specific path prefix instead of "/".
// When traefikoidc protects some but not all paths on a domain, set this to the
// middleware's path prefix (e.g. "/app-protegido") so the browser does not send
// the OIDC session cookies to unprotected paths — preventing "Request Header
// Or Cookie Too Large" (431) errors on those paths.
// Default "/" (all paths, current behaviour).
CookiePath string `json:"cookiePath,omitempty"`
EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
// CACertPath is an optional filesystem path to a PEM-encoded CA bundle used
// to verify the OIDC provider's TLS certificate. Use this when the provider
// is signed by an internal/private CA that is not in the system trust store.
@@ -763,32 +769,7 @@ func validateTemplateSecure(templateStr string) error {
// Returns true if the URL is valid and secure (HTTPS), false otherwise.
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
if err != nil || u.Host == "" {
return false
}
if u.Scheme == "https" {
return true
}
// Permit plaintext HTTP only for loopback hosts (local development,
// in-cluster sidecar providers, tests). Loopback traffic never leaves the
// host, so it is not exposed to network MITM; remote providers must use
// HTTPS. Mirrors the RFC 8252 loopback allowance.
if u.Scheme == "http" && isLoopbackHost(u.Hostname()) {
return true
}
return false
}
// isLoopbackHost reports whether host is "localhost" or a loopback IP literal
// (127.0.0.0/8 or ::1).
func isLoopbackHost(host string) bool {
if strings.EqualFold(host, "localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
return ip.IsLoopback()
}
return false
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
+11 -41
View File
@@ -106,9 +106,8 @@ func (rm *ResourceManager) GetCache(key string) interface{} {
case "jwk-cache":
cache = cacheManager.GetSharedJWKCache()
default:
// Generic cache implementation; bind cleanup goroutine to the manager's
// shutdown channel so it exits when the ResourceManager shuts down.
cache = newGenericCacheWithOwner(1*time.Hour, rm.logger, rm.shutdownChan)
// Generic cache implementation
cache = NewGenericCache(1*time.Hour, rm.logger)
}
rm.caches[key] = cache
@@ -264,19 +263,6 @@ func (rm *ResourceManager) cleanupInstance(instanceID string) {
// This is a hook for future instance-specific cleanup needs
}
// liveInstanceCount tracks the number of fully-constructed TraefikOidc plugin
// instances alive in this process. Process-global singleton tasks (such as the
// shared token-cleanup) must only be stopped when the LAST instance shuts down,
// otherwise one instance's teardown would disable them for all survivors.
var liveInstanceCount int32
// registerLiveInstance records a newly constructed plugin instance.
func registerLiveInstance() { atomic.AddInt32(&liveInstanceCount, 1) }
// unregisterLiveInstance records a plugin instance shutting down and returns the
// number of instances still alive afterwards.
func unregisterLiveInstance() int32 { return atomic.AddInt32(&liveInstanceCount, -1) }
// Shutdown gracefully shuts down all managed resources
func (rm *ResourceManager) Shutdown(ctx context.Context) error {
var err error
@@ -515,31 +501,20 @@ func (p *GoroutinePool) Shutdown(ctx context.Context) error {
// GenericCache provides a simple cache implementation for testing
type GenericCache struct {
data map[string]interface{}
// ownerStopChan, when non-nil, signals the cleanup goroutine to exit when
// the owning ResourceManager shuts down, so the goroutine cannot outlive it.
ownerStopChan <-chan struct{}
logger *Logger
stopChan chan struct{}
ttl time.Duration
mu sync.RWMutex
data map[string]interface{}
logger *Logger
stopChan chan struct{}
ttl time.Duration
mu sync.RWMutex
}
// NewGenericCache creates a new generic cache
func NewGenericCache(ttl time.Duration, logger *Logger) *GenericCache {
return newGenericCacheWithOwner(ttl, logger, nil)
}
// newGenericCacheWithOwner creates a generic cache whose cleanup goroutine also
// exits when ownerStopChan is closed (typically the ResourceManager shutdown
// channel), guaranteeing the goroutine is stoppable on shutdown.
func newGenericCacheWithOwner(ttl time.Duration, logger *Logger, ownerStopChan <-chan struct{}) *GenericCache {
cache := &GenericCache{
data: make(map[string]interface{}),
ttl: ttl,
logger: logger,
stopChan: make(chan struct{}),
ownerStopChan: ownerStopChan,
data: make(map[string]interface{}),
ttl: ttl,
logger: logger,
stopChan: make(chan struct{}),
}
// Start cleanup routine
@@ -595,11 +570,6 @@ func (gc *GenericCache) cleanupRoutine() {
gc.mu.Unlock()
case <-gc.stopChan:
return
case <-gc.ownerStopChan:
// Owning ResourceManager is shutting down; exit so the goroutine
// does not outlive its owner. A nil channel blocks forever, so this
// case is inert when no owner is set.
return
}
}
}
+15 -27
View File
@@ -296,12 +296,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
// Create a TraefikOidc instance with context
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
}
plugin, err := NewWithContext(ctx, config, nil, "test")
@@ -353,9 +350,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
initialGoroutines := runtime.NumGoroutine()
configs := []Config{
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100},
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
}
var plugins []*TraefikOidc
@@ -435,12 +432,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
for i := 0; i < 3; i++ {
ctx := context.Background()
config := &Config{
ProviderURL: mockServers[i].URL,
ClientID: fmt.Sprintf("client%d", i),
ClientSecret: fmt.Sprintf("secret%d", i),
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
ProviderURL: mockServers[i].URL,
ClientID: fmt.Sprintf("client%d", i),
ClientSecret: fmt.Sprintf("secret%d", i),
}
plugin, err := NewWithContext(ctx, config, nil, fmt.Sprintf("test-%d", i))
@@ -601,12 +595,9 @@ func TestBackwardCompatibility(t *testing.T) {
t.Run("LegacyNewFunction", func(t *testing.T) {
// Test that the original New function still works
config := &Config{
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
}
handler, err := New(context.Background(), nil, config, "test")
@@ -626,12 +617,9 @@ func TestBackwardCompatibility(t *testing.T) {
t.Run("ExistingAPICompatibility", func(t *testing.T) {
config := &Config{
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
RateLimit: 100,
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
}
handler, _ := New(context.Background(), nil, config, "test")
+142
View File
@@ -0,0 +1,142 @@
package traefikoidc
import (
"bytes"
"context"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
)
// pluginVersion is bumped manually on each release. Keep in sync with the
// most recent git tag (see `git tag --sort=-v:refname | head -1`).
const pluginVersion = "1.0.11"
const (
telemetryProject = "traefikoidc"
telemetryTimeout = 2 * time.Second
)
// telemetryEndpoint is intentionally a var rather than a const so the test
// suite in this package can retarget it at an httptest server. Production
// code never mutates it.
var telemetryEndpoint = "https://oss.raczylo.com/v1/ping"
// telemetryOnce guarantees a single anonymous "plugin loaded" ping per
// process lifetime. Traefik can instantiate a middleware many times per
// process (one per route using the plugin); the sync.Once gate keeps the
// fire-and-forget call from amplifying into many pings.
//
// Reset in tests via `telemetryOnce = sync.Once{}`.
var telemetryOnce sync.Once
// telemetryInflight tracks any background goroutine started by sendTelemetry.
// Tests Wait on it to drain in-flight goroutines before mutating package
// state. Production code never calls Wait — the goroutine is fire-and-forget.
var telemetryInflight sync.WaitGroup
// sendTelemetry fires one anonymous usage ping in the background. It is
// failproof by contract:
//
// - never blocks the caller
// - never panics (the goroutine recovers internally)
// - never returns errors
// - silently dropped on invalid input, env-driven opt-out, or network failure
//
// Opt-out is honored via any of:
//
// - DO_NOT_TRACK=1
// - OSS_TELEMETRY_DISABLED=1
// - TRAEFIKOIDC_DISABLE_TELEMETRY=1
//
// Yaegi note: this file deliberately avoids generics (atomic.Pointer[T]) and
// range-over-int (Go 1.22) so it interprets under any reasonably recent
// Traefik yaegi runtime.
func sendTelemetry(version string) {
telemetryOnce.Do(func() {
if telemetryDisabledByEnv() {
return
}
if !validTelemetryVersion(version) {
return
}
telemetryInflight.Add(1)
go func() {
defer telemetryInflight.Done()
defer func() { _ = recover() }()
doTelemetryPost(version)
}()
})
}
func telemetryDisabledByEnv() bool {
keys := []string{
"DO_NOT_TRACK",
"OSS_TELEMETRY_DISABLED",
"TRAEFIKOIDC_DISABLE_TELEMETRY",
}
for _, k := range keys {
v := strings.ToLower(strings.TrimSpace(os.Getenv(k)))
if v == "1" || v == "true" || v == "yes" || v == "on" {
return true
}
}
return false
}
// validTelemetryVersion mirrors the server-side regex ^[A-Za-z0-9.+_-]{1,32}$
// using a byte loop. No allocation, no regexp dependency.
//
// Yaegi note: written as an `||` chain rather than `switch{case A,B,C:}` —
// some yaegi releases mis-evaluate comma-separated case expressions in
// switch-true blocks, returning false for all inputs.
func validTelemetryVersion(v string) bool {
if len(v) == 0 || len(v) > 32 {
return false
}
for i := 0; i < len(v); i++ {
c := v[i]
ok := (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') ||
c == '.' || c == '+' || c == '_' || c == '-'
if !ok {
return false
}
}
return true
}
// doTelemetryPost builds the JSON body manually. The project name is a
// constant and the version is pre-validated against an ASCII-only allowlist,
// so direct concatenation needs no JSON escaping.
func doTelemetryPost(version string) {
body := make([]byte, 0, 96)
body = append(body, `{"project":"`...)
body = append(body, telemetryProject...)
body = append(body, `","version":"`...)
body = append(body, version...)
body = append(body, `","ts":`...)
body = strconv.AppendInt(body, time.Now().Unix(), 10)
body = append(body, '}')
ctx, cancel := context.WithTimeout(context.Background(), telemetryTimeout)
defer cancel()
url := telemetryEndpoint
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: telemetryTimeout}
resp, err := client.Do(req)
if err != nil {
return
}
_ = resp.Body.Close()
}
+167
View File
@@ -0,0 +1,167 @@
package traefikoidc
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// resetTelemetryState restores package-level mutable state so tests do not
// contaminate one another. The cleanup waits for any in-flight ping goroutine
// to finish before restoring telemetryEndpoint — without that drain step the
// goroutine and the cleanup would race on the var.
func resetTelemetryState(t *testing.T) {
t.Helper()
telemetryOnce = sync.Once{}
prev := telemetryEndpoint
t.Cleanup(func() {
telemetryInflight.Wait()
telemetryEndpoint = prev
telemetryOnce = sync.Once{}
})
}
func newTelemetryServer(t *testing.T, status int) (hits *int32, lastBody func() string) {
t.Helper()
var counter int32
var mu sync.Mutex
var body string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&counter, 1)
b, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
mu.Lock()
body = string(b)
mu.Unlock()
w.WriteHeader(status)
}))
telemetryEndpoint = srv.URL
t.Cleanup(srv.Close)
return &counter, func() string {
mu.Lock()
defer mu.Unlock()
return body
}
}
func TestValidTelemetryVersion(t *testing.T) {
good := []string{"1.2.3", "1.4.0-beta1", "2.0", "v1.0.0", "1.0.0+meta", "dev"}
for _, v := range good {
if !validTelemetryVersion(v) {
t.Errorf("validTelemetryVersion(%q) = false, want true", v)
}
}
bad := []string{"", "has space", "semi;colon", strings.Repeat("1", 33)}
for _, v := range bad {
if validTelemetryVersion(v) {
t.Errorf("validTelemetryVersion(%q) = true, want false", v)
}
}
}
func TestTelemetryDisabledByEnv(t *testing.T) {
for _, k := range []string{"DO_NOT_TRACK", "OSS_TELEMETRY_DISABLED", "TRAEFIKOIDC_DISABLE_TELEMETRY"} {
t.Run(k, func(t *testing.T) {
t.Setenv(k, "1")
if !telemetryDisabledByEnv() {
t.Fatalf("%s=1 should disable", k)
}
})
}
t.Run("falsy_values_do_not_disable", func(t *testing.T) {
t.Setenv("DO_NOT_TRACK", "0")
t.Setenv("OSS_TELEMETRY_DISABLED", "false")
t.Setenv("TRAEFIKOIDC_DISABLE_TELEMETRY", "no")
if telemetryDisabledByEnv() {
t.Fatal("falsy env values should not disable")
}
})
}
func TestSendTelemetry_FiresOnceAcrossManyCalls(t *testing.T) {
resetTelemetryState(t)
hits, lastBody := newTelemetryServer(t, http.StatusNoContent)
for i := 0; i < 50; i++ {
sendTelemetry("1.2.3")
}
telemetryInflight.Wait()
if got := atomic.LoadInt32(hits); got != 1 {
t.Fatalf("expected exactly 1 hit, got %d", got)
}
var payload struct {
Project string `json:"project"`
Version string `json:"version"`
Ts int64 `json:"ts"`
}
if err := json.Unmarshal([]byte(lastBody()), &payload); err != nil {
t.Fatalf("server received non-JSON body: %q (err: %v)", lastBody(), err)
}
if payload.Project != "traefikoidc" || payload.Version != "1.2.3" || payload.Ts <= 0 {
t.Fatalf("unexpected payload: %+v", payload)
}
}
func TestSendTelemetry_RespectsDisableEnv(t *testing.T) {
resetTelemetryState(t)
hits, _ := newTelemetryServer(t, http.StatusNoContent)
t.Setenv("DO_NOT_TRACK", "1")
sendTelemetry("1.2.3")
telemetryInflight.Wait()
if got := atomic.LoadInt32(hits); got != 0 {
t.Fatalf("DO_NOT_TRACK should suppress; got %d hits", got)
}
}
func TestSendTelemetry_DropsInvalidVersion(t *testing.T) {
resetTelemetryState(t)
hits, _ := newTelemetryServer(t, http.StatusNoContent)
sendTelemetry("has space")
telemetryInflight.Wait()
if got := atomic.LoadInt32(hits); got != 0 {
t.Fatalf("invalid version should suppress; got %d hits", got)
}
}
func TestSendTelemetry_DoesNotBlock(t *testing.T) {
resetTelemetryState(t)
// Hanging server proves the caller is never blocked. The 2s context
// timeout in doTelemetryPost ensures the goroutine eventually exits;
// resetTelemetryState's cleanup waits for that drain before restoring
// telemetryEndpoint so there is no race with this test's mutation.
hung := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
time.Sleep(5 * time.Second)
}))
t.Cleanup(hung.Close)
telemetryEndpoint = hung.URL
start := time.Now()
sendTelemetry("1.2.3")
if elapsed := time.Since(start); elapsed > 50*time.Millisecond {
t.Fatalf("sendTelemetry blocked for %v, expected near-instant return", elapsed)
}
}
func TestSendTelemetry_SurvivesServerError(t *testing.T) {
resetTelemetryState(t)
hits, _ := newTelemetryServer(t, http.StatusInternalServerError)
sendTelemetry("1.2.3")
telemetryInflight.Wait()
if got := atomic.LoadInt32(hits); got != 1 {
t.Fatalf("request should still reach server even on 500; got %d hits", got)
}
}
+14 -29
View File
@@ -21,16 +21,13 @@ type IntrospectionResponse struct {
Username string `json:"username,omitempty"`
TokenType string `json:"token_type,omitempty"`
Sub string `json:"sub,omitempty"`
// Aud holds the introspection audience. Per RFC 7662 it may be a single
// string or an array of strings, so it is decoded as interface{} and
// matched with verifyAudience (which handles both shapes).
Aud interface{} `json:"aud,omitempty"`
Iss string `json:"iss,omitempty"`
Jti string `json:"jti,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Active bool `json:"active"`
Aud string `json:"aud,omitempty"`
Iss string `json:"iss,omitempty"`
Jti string `json:"jti,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Active bool `json:"active"`
}
// introspectToken performs OAuth 2.0 Token Introspection (RFC 7662) for an opaque token.
@@ -123,7 +120,7 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
// Parse response per RFC 7662 Section 2.2
var introspectionResp IntrospectionResponse
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&introspectionResp); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&introspectionResp); err != nil {
return nil, fmt.Errorf("failed to decode introspection response: %w", err)
}
@@ -131,12 +128,6 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err
if t.introspectionCache != nil {
// Cache for a short duration or until token expiry (whichever is shorter)
cacheDuration := 5 * time.Minute
// When introspection is REQUIRED, operators expect near-real-time
// revocation; cap the positive-result cache so a token revoked at the
// provider cannot keep passing for the full 5 minutes (rank 8).
if t.requireTokenIntrospection && cacheDuration > 30*time.Second {
cacheDuration = 30 * time.Second
}
if introspectionResp.Exp > 0 {
expTime := time.Unix(introspectionResp.Exp, 0)
untilExp := time.Until(expTime)
@@ -206,18 +197,12 @@ func (t *TraefikOidc) validateOpaqueToken(token string) error {
}
}
// Validate audience if configured. When a distinct API audience is
// configured (audience != clientID), the introspection response MUST carry
// a matching audience. Fail closed on a missing or mismatched aud: a token
// whose audience cannot be confirmed must not be accepted, otherwise a
// token minted for a different audience would pass. aud may be a single
// string or an array of strings (RFC 7662); verifyAudience handles both.
if t.audience != "" && t.audience != t.clientID {
if resp.Aud == nil {
return fmt.Errorf("invalid audience: expected %s, introspection response has no audience", t.audience)
}
if err := verifyAudience(resp.Aud, t.audience); err != nil {
return fmt.Errorf("invalid audience: expected %s: %w", t.audience, err)
// Validate audience if configured
// Note: For opaque tokens, audience validation via introspection may be limited
// depending on what the introspection endpoint returns
if t.audience != "" && t.audience != t.clientID && resp.Aud != "" {
if resp.Aud != t.audience {
return fmt.Errorf("invalid audience: expected %s, got %s", t.audience, resp.Aud)
}
}
+6 -9
View File
@@ -5,8 +5,6 @@ package traefikoidc
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
@@ -214,13 +212,11 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
//
//nolint:gocognit,gocyclo // Complex token type detection with multiple provider-specific checks
func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool {
// Key on a hash of the FULL token. The first 32 characters of a JWT are
// only the base64url-encoded header, which is identical for every token
// sharing the same alg+kid, so distinct tokens (e.g. an ID token and an
// access token from the same issuer) would otherwise collide on the cache
// key and be mis-classified.
sum := sha256.Sum256([]byte(token))
cacheKey := hex.EncodeToString(sum[:])
// Use first 32 chars of token as cache key (sufficient for uniqueness)
cacheKey := token
if len(token) > 32 {
cacheKey = token[:32]
}
// Check cache first
if t.tokenTypeCache != nil {
@@ -862,6 +858,7 @@ func (t *TraefikOidc) isAzureProvider() bool {
strings.Contains(issuerURL, "login.windows.net")
}
// startTokenCleanup starts background cleanup goroutines for cache maintenance.
// It runs periodic cleanup of token cache, JWK cache, and session chunks.
// Includes panic recovery to ensure stability.
+535 -5
View File
@@ -3,9 +3,7 @@ package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
@@ -887,9 +885,10 @@ func TestDetectTokenTypeCaching(t *testing.T) {
},
}
token := "test-token-for-caching-with-enough-characters-for-key"
// The cache key is a SHA-256 hash of the full token (collision-resistant).
sum := sha256.Sum256([]byte(token))
cacheKey := hex.EncodeToString(sum[:])
cacheKey := token
if len(token) > 32 {
cacheKey = token[:32]
}
result := tr.detectTokenType(jwt, token)
if !result {
@@ -912,6 +911,521 @@ func TestDetectTokenTypeCaching(t *testing.T) {
}
}
// =============================================================================
// TOKEN VALIDATOR TESTS
// =============================================================================
func TestNewTokenValidator(t *testing.T) {
validator := NewTokenValidator(nil)
if validator == nil {
t.Fatal("Expected non-nil token validator")
}
if validator.logger == nil {
t.Error("Expected logger to be initialized")
}
}
func TestNewTokenValidatorWithLogger(t *testing.T) {
logger := GetSingletonNoOpLogger()
validator := NewTokenValidator(logger)
if validator == nil {
t.Fatal("Expected non-nil token validator")
}
if validator.logger != logger {
t.Error("Expected provided logger to be used")
}
}
func TestValidateTokenEmpty(t *testing.T) {
validator := NewTokenValidator(nil)
result := validator.ValidateToken("", false)
if result.Valid {
t.Error("Expected invalid result for empty token")
}
if result.Error == nil {
t.Error("Expected error for empty token")
}
if !strings.Contains(result.Error.Error(), "empty") {
t.Errorf("Expected 'empty' in error, got: %v", result.Error)
}
}
func TestValidateTokenRequireJWT(t *testing.T) {
validator := NewTokenValidator(nil)
result := validator.ValidateToken("opaque_token_value_here", true)
if result.Valid {
t.Error("Expected invalid result for opaque token when JWT required")
}
if result.Error == nil {
t.Error("Expected error when JWT required but opaque token provided")
}
}
func TestValidateJWTValidFormat(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if !result.Valid {
t.Errorf("Expected valid result, got error: %v", result.Error)
}
if result.TokenType != "JWT" {
t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType)
}
if result.Claims == nil {
t.Error("Expected claims to be parsed")
}
if result.Expiry == nil {
t.Error("Expected expiry to be extracted")
}
if result.IssuedAt == nil {
t.Error("Expected issued at to be extracted")
}
}
func TestValidateJWTExpiredToken(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(-1 * time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for expired token")
}
if result.Error == nil {
t.Error("Expected error for expired token")
}
if !strings.Contains(result.Error.Error(), "expired") {
t.Errorf("Expected 'expired' in error, got: %v", result.Error)
}
}
func TestValidateJWTFutureIssuedAt(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(2 * time.Hour).Unix(),
"iat": time.Now().Add(10 * time.Minute).Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for future iat")
}
if result.Error == nil {
t.Error("Expected error for future iat")
}
if !strings.Contains(result.Error.Error(), "future") {
t.Errorf("Expected 'future' in error, got: %v", result.Error)
}
}
func TestValidateJWTNotBeforeClaim(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"exp": time.Now().Add(2 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"nbf": time.Now().Add(1 * time.Hour).Unix(),
}
token := createTestJWTSimple(claims)
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for nbf in future")
}
if result.Error == nil {
t.Error("Expected error for nbf in future")
}
if !strings.Contains(result.Error.Error(), "not yet valid") {
t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error)
}
}
func TestValidateJWTInvalidFormat(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
}{
{"single part", "eyJhbGciOiJIUzI1NiJ9"},
{"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"},
{"four parts", "part1.part2.part3.part4"},
{"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ValidateToken(tt.token, true)
if result.Valid {
t.Error("Expected invalid result for malformed JWT")
}
if result.Error == nil {
t.Error("Expected error for malformed JWT")
}
})
}
}
func TestValidateOpaqueTokenValid(t *testing.T) {
validator := NewTokenValidator(nil)
token := "sk_live_abcdef123456GHIJKL789"
result := validator.ValidateToken(token, false)
if !result.Valid {
t.Errorf("Expected valid result, got error: %v", result.Error)
}
if result.TokenType != "Opaque" {
t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType)
}
}
func TestValidateOpaqueTokenTooShort(t *testing.T) {
validator := NewTokenValidator(nil)
token := "short"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for short token")
}
if result.Error == nil {
t.Error("Expected error for short token")
}
if !strings.Contains(result.Error.Error(), "too short") {
t.Errorf("Expected 'too short' in error, got: %v", result.Error)
}
}
func TestValidateOpaqueTokenWithSpaces(t *testing.T) {
validator := NewTokenValidator(nil)
token := "this token has spaces in it"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for token with spaces")
}
if result.Error == nil {
t.Error("Expected error for token with spaces")
}
if !strings.Contains(result.Error.Error(), "spaces") {
t.Errorf("Expected 'spaces' in error, got: %v", result.Error)
}
}
func TestValidateOpaqueTokenControlCharacters(t *testing.T) {
validator := NewTokenValidator(nil)
token := "token_with\x00control_char"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for token with control characters")
}
if result.Error == nil {
t.Error("Expected error for token with control characters")
}
if !strings.Contains(result.Error.Error(), "control character") {
t.Errorf("Expected 'control character' in error, got: %v", result.Error)
}
}
func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) {
validator := NewTokenValidator(nil)
token := "aaaaaabbbbbbccccccdddd"
result := validator.ValidateToken(token, false)
if result.Valid {
t.Error("Expected invalid result for low entropy token")
}
if result.Error == nil {
t.Error("Expected error for low entropy token")
}
if !strings.Contains(result.Error.Error(), "entropy") {
t.Errorf("Expected 'entropy' in error, got: %v", result.Error)
}
}
func TestIsValidBase64URL(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
input string
expected bool
}{
{"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true},
{"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true},
{"valid numbers", "0123456789", true},
{"valid dash", "abc-def", true},
{"valid underscore", "abc_def", true},
{"valid equals", "abc=", true},
{"invalid at sign", "abc@def", false},
{"invalid space", "abc def", false},
{"invalid plus", "abc+def", false},
{"invalid slash", "abc/def", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.isValidBase64URL(tt.input)
if result != tt.expected {
t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result)
}
})
}
}
func TestExtractTime(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
claim interface{}
name string
expected bool
}{
{name: "float64", claim: float64(1609459200), expected: true},
{name: "int64", claim: int64(1609459200), expected: true},
{name: "int", claim: int(1609459200), expected: true},
{name: "string", claim: "not a timestamp", expected: false},
{name: "nil", claim: nil, expected: false},
{name: "map", claim: map[string]interface{}{}, expected: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.extractTime(tt.claim)
if tt.expected && result == nil {
t.Error("Expected non-nil time")
}
if !tt.expected && result != nil {
t.Error("Expected nil time")
}
})
}
}
func TestValidateTokenSize(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
maxSize int
expectError bool
}{
{"within limit", "short_token", 20, false},
{"at limit", "exactly_twenty_c", 16, false},
{"exceeds limit", "this_token_is_too_long", 10, true},
{"empty token", "", 10, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateTokenSize(tt.token, tt.maxSize)
if tt.expectError && err == nil {
t.Error("Expected error for oversized token")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if err != nil && !strings.Contains(err.Error(), "exceeds") {
t.Errorf("Expected 'exceeds' in error, got: %v", err)
}
})
}
}
func TestExtractClaims(t *testing.T) {
validator := NewTokenValidator(nil)
claims := map[string]interface{}{
"sub": "user123",
"email": "user@example.com",
"exp": float64(1609459200),
}
token := createTestJWTSimple(claims)
extracted, err := validator.ExtractClaims(token)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if extracted == nil {
t.Fatal("Expected non-nil claims")
}
if extracted["sub"] != "user123" {
t.Errorf("Expected sub 'user123', got %v", extracted["sub"])
}
if extracted["email"] != "user@example.com" {
t.Errorf("Expected email 'user@example.com', got %v", extracted["email"])
}
}
func TestExtractClaimsInvalidFormat(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
}{
{"single part", "onlyonepart"},
{"two parts", "two.parts"},
{"four parts", "one.two.three.four"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := validator.ExtractClaims(tt.token)
if err == nil {
t.Error("Expected error for invalid format")
}
if !strings.Contains(err.Error(), "invalid JWT format") {
t.Errorf("Expected 'invalid JWT format' in error, got: %v", err)
}
})
}
}
func TestCompareTokensEqual(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := "secret_token_12345"
token2 := "secret_token_12345"
if !validator.CompareTokens(token1, token2) {
t.Error("Expected tokens to be equal")
}
}
func TestCompareTokensDifferent(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := "secret_token_12345"
token2 := "secret_token_54321"
if validator.CompareTokens(token1, token2) {
t.Error("Expected tokens to be different")
}
}
func TestCompareTokensDifferentLength(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := "short"
token2 := "much_longer_token"
if validator.CompareTokens(token1, token2) {
t.Error("Expected tokens to be different (different lengths)")
}
}
func TestCompareTokensEmpty(t *testing.T) {
validator := NewTokenValidator(nil)
token1 := ""
token2 := ""
if !validator.CompareTokens(token1, token2) {
t.Error("Expected empty tokens to be equal")
}
}
func TestValidateTokenMaliciousPayloads(t *testing.T) {
validator := NewTokenValidator(nil)
tests := []struct {
name string
token string
}{
{"sql injection attempt", "'; DROP TABLE users; --"},
{"xss attempt", "<script>alert('xss')</script>"},
{"path traversal", "../../../etc/passwd"},
{"null bytes", "token\x00with\x00nulls"},
{"unicode exploit", "token\u0000\u0001\u0002"},
{"extremely long", strings.Repeat("a", 100000)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.ValidateToken(tt.token, false)
if result.Valid {
if result.Claims != nil {
t.Logf("Token considered valid: %s", tt.name)
}
} else {
if result.Error == nil {
t.Error("Expected error for malicious payload")
}
}
})
}
}
// =============================================================================
// CONSOLIDATED TOKEN TESTS
// =============================================================================
@@ -1584,3 +2098,19 @@ func createTokenOfSize(baseToken string, targetSize int) string {
return baseToken
}
func createTestJWTSimple(claims map[string]interface{}) string {
header := map[string]interface{}{
"alg": "HS256",
"typ": "JWT",
}
headerJSON, _ := json.Marshal(header)
claimsJSON, _ := json.Marshal(claims)
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature"))
return headerB64 + "." + claimsB64 + "." + signature
}
+2 -15
View File
@@ -149,15 +149,7 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo
if rs.idToken != "" {
return t.validateTokenExpiryRS(rs, rs.idToken)
}
// No ID token to corroborate an access token we cannot verify
// (Azure nonce-bearing Graph access tokens carry a proprietary,
// client-unverifiable signature). Do NOT authenticate on an
// unverified token: refresh if a refresh token is available,
// otherwise force re-authentication.
if rs.refreshToken != "" {
return false, true, false
}
return false, false, true
return true, false, false
}
}
@@ -166,12 +158,7 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo
if rs.refreshToken != "" {
return false, true, false
}
// Opaque access token, no ID token to corroborate it, and
// introspection was unavailable/disabled/errored (e.g.
// circuit-breaker open). There is nothing left to verify the token
// against, so fail closed and force re-authentication rather than
// trusting an unverified opaque token.
return false, false, true
return true, false, false
}
if err := t.verifyToken(rs.idToken); err != nil {
if strings.Contains(err.Error(), "token has expired") {
+263
View File
@@ -0,0 +1,263 @@
package traefikoidc
import (
"bytes"
"encoding/base64"
"fmt"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/pool"
)
// TokenValidator provides unified token validation functionality
type TokenValidator struct {
logger *Logger
}
// NewTokenValidator creates a new token validator
func NewTokenValidator(logger *Logger) *TokenValidator {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &TokenValidator{
logger: logger,
}
}
// TokenValidationResult contains the result of token validation
type TokenValidationResult struct {
Error error
Claims map[string]interface{}
Expiry *time.Time
IssuedAt *time.Time
TokenType string
Valid bool
}
// ValidateToken performs comprehensive token validation
func (v *TokenValidator) ValidateToken(token string, requireJWT bool) TokenValidationResult {
result := TokenValidationResult{}
// Basic validation
if token == "" {
result.Error = fmt.Errorf("token is empty")
return result
}
// Check if it's a JWT or opaque token
dotCount := strings.Count(token, ".")
isJWT := dotCount == 2
if requireJWT && !isJWT {
result.Error = fmt.Errorf("token is not a valid JWT (found %d dots, expected 2)", dotCount)
return result
}
if isJWT {
return v.validateJWT(token)
} else {
return v.validateOpaqueToken(token)
}
}
// validateJWT validates a JWT token
func (v *TokenValidator) validateJWT(token string) TokenValidationResult {
result := TokenValidationResult{
TokenType: "JWT",
}
parts := strings.Split(token, ".")
if len(parts) != 3 {
result.Error = fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
return result
}
// Validate each part
for i, part := range parts {
if part == "" {
result.Error = fmt.Errorf("JWT part %d is empty", i)
return result
}
// Check for valid base64url characters
if !v.isValidBase64URL(part) {
result.Error = fmt.Errorf("JWT part %d contains invalid base64url characters", i)
return result
}
}
// Decode and parse claims
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
result.Error = fmt.Errorf("failed to decode JWT payload: %w", err)
return result
}
var claims map[string]interface{}
pm := pool.Get()
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
defer pm.PutJSONDecoder(decoder)
if err := decoder.Decode(&claims); err != nil {
result.Error = fmt.Errorf("failed to parse JWT claims: %w", err)
return result
}
result.Claims = claims
// Extract standard claims
if exp, ok := claims["exp"]; ok {
expTime := v.extractTime(exp)
if expTime != nil {
result.Expiry = expTime
// Check if expired
if time.Now().After(*expTime) {
result.Error = fmt.Errorf("token is expired (expired at %v)", expTime.Format(time.RFC3339))
return result
}
}
}
if iat, ok := claims["iat"]; ok {
iatTime := v.extractTime(iat)
if iatTime != nil {
result.IssuedAt = iatTime
// Check if issued in future
if iatTime.After(time.Now().Add(5 * time.Minute)) {
result.Error = fmt.Errorf("token issued in future (iat: %v)", iatTime.Format(time.RFC3339))
return result
}
}
}
// Check nbf (not before)
if nbf, ok := claims["nbf"]; ok {
nbfTime := v.extractTime(nbf)
if nbfTime != nil && time.Now().Before(*nbfTime) {
result.Error = fmt.Errorf("token not yet valid (nbf: %v)", nbfTime.Format(time.RFC3339))
return result
}
}
result.Valid = true
return result
}
// validateOpaqueToken validates an opaque token
func (v *TokenValidator) validateOpaqueToken(token string) TokenValidationResult {
result := TokenValidationResult{
TokenType: "Opaque",
}
// Check minimum length
if len(token) < 20 {
result.Error = fmt.Errorf("opaque token too short (length: %d)", len(token))
return result
}
// Check for spaces
if strings.Contains(token, " ") {
result.Error = fmt.Errorf("opaque token contains spaces")
return result
}
// Check for control characters
for i, char := range token {
if char < 32 || char == 127 {
result.Error = fmt.Errorf("opaque token contains control character at position %d", i)
return result
}
}
// Check entropy
if len(token) >= 20 {
uniqueChars := make(map[rune]bool)
for _, char := range token {
uniqueChars[char] = true
}
if len(uniqueChars) < 8 {
result.Error = fmt.Errorf("opaque token has insufficient entropy (unique chars: %d)", len(uniqueChars))
return result
}
}
result.Valid = true
return result
}
// isValidBase64URL checks if a string contains only valid base64url characters
func (v *TokenValidator) isValidBase64URL(s string) bool {
for _, char := range s {
if !((char >= 'A' && char <= 'Z') ||
(char >= 'a' && char <= 'z') ||
(char >= '0' && char <= '9') ||
char == '-' || char == '_' || char == '=') {
return false
}
}
return true
}
// extractTime extracts a time.Time from various claim formats
func (v *TokenValidator) extractTime(claim interface{}) *time.Time {
var timestamp int64
switch val := claim.(type) {
case float64:
timestamp = int64(val)
case int64:
timestamp = val
case int:
timestamp = int64(val)
default:
return nil
}
t := time.Unix(timestamp, 0)
return &t
}
// ValidateTokenSize checks if token size is within acceptable limits
func (v *TokenValidator) ValidateTokenSize(token string, maxSize int) error {
if len(token) > maxSize {
return fmt.Errorf("token exceeds maximum size (size: %d, max: %d)", len(token), maxSize)
}
return nil
}
// ExtractClaims extracts claims from a JWT without full validation
func (v *TokenValidator) ExtractClaims(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %w", err)
}
var claims map[string]interface{}
pm := pool.Get()
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
defer pm.PutJSONDecoder(decoder)
if err := decoder.Decode(&claims); err != nil {
return nil, fmt.Errorf("failed to parse claims: %w", err)
}
return claims, nil
}
// CompareTokens safely compares two tokens for equality
func (v *TokenValidator) CompareTokens(token1, token2 string) bool {
if len(token1) != len(token2) {
return false
}
// Use constant-time comparison to prevent timing attacks
var result byte
for i := 0; i < len(token1); i++ {
result |= token1[i] ^ token2[i]
}
return result == 0
}
+8 -45
View File
@@ -603,28 +603,15 @@ func (c *UniversalCache) removeItem(key string, item *CacheItem) {
// evictOldest evicts the oldest item from the cache (must be called with lock held)
func (c *UniversalCache) evictOldest() {
elem := c.lruList.Back()
if elem == nil {
return
}
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
if item, exists := c.items[key]; exists && item.element == elem {
c.removeItem(key, item)
atomic.AddInt64(&c.evictions, 1)
if c.logger.IsDebug() {
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
if elem := c.lruList.Back(); elem != nil {
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
if item, exists := c.items[key]; exists {
c.removeItem(key, item)
atomic.AddInt64(&c.evictions, 1)
if c.logger.IsDebug() {
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
}
}
return
}
// Defensive forward-progress guard: the back node is dangling — its key is
// absent from c.items, or c.items[key] points at a newer node (a stale
// duplicate). Drop the node directly so an eviction loop
// (`for ... && c.lruList.Len() > 0`) is guaranteed to terminate and can
// never spin holding c.mu.Lock(). With the updateLocalCache replace-in-place
// fix this branch should be unreachable, but it makes the spin impossible.
c.lruList.Remove(elem)
if c.currentSize > 0 {
c.currentSize--
}
}
@@ -957,30 +944,6 @@ func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl tim
}
now := time.Now()
// Replace an existing entry in place: update the item and move its single
// list node to the front. Without this, a repeat populate of the same key
// (the per-request Get->backend-hit path) would PushFront a duplicate node
// and overwrite c.items[key], orphaning the previous node. Orphans inflate
// currentMemory/currentSize and, once eviction deletes the key, leave a
// Back() node whose key is absent from c.items — so evictOldest() spins
// while holding c.mu.Lock(): the 100%-CPU write-lock convoy seen in pprof.
// setLocal dedups the same way; evictOldest also guards any dangling node.
if existing, exists := c.items[key]; exists {
c.currentMemory -= existing.Size
c.lruList.Remove(existing.element)
existing.Value = value
existing.Size = size
existing.ExpiresAt = now.Add(ttl)
existing.LastAccessed = now
existing.AccessCount++
existing.element = c.lruList.PushFront(key)
c.currentMemory += size
return nil
}
item := &CacheItem{
Key: key,
Value: value,
-84
View File
@@ -1,84 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
// newOrphanTestCache builds a Token-type cache with background cleanup disabled
// so the test fully controls lruList/items state.
func newOrphanTestCache(maxMem int64) *UniversalCache {
return NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
DefaultTTL: time.Hour,
MaxSize: 1_000_000, // large: keep the size-branch out of the way
MaxMemoryBytes: maxMem,
EnableMemoryLimit: maxMem > 0,
SkipAutoCleanup: true,
EnableAutoCleanup: false,
})
}
// TestUpdateLocalCache_NoOrphanElements is the direct red test: repeatedly
// populating the SAME key via updateLocalCache (the per-request Get->backend-hit
// path) must NOT leave dangling lruList elements. Today updateLocalCache blindly
// PushFronts + overwrites c.items[key] without removing the prior element, so the
// list grows one orphan per call while items stays at 1 entry.
func TestUpdateLocalCache_NoOrphanElements(t *testing.T) {
c := newOrphanTestCache(0) // memory limit off: isolate the orphan, no eviction
const key = "same-key"
for range 5 {
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
t.Fatalf("updateLocalCache: %v", err)
}
}
c.mu.RLock()
listLen := c.lruList.Len()
itemCount := len(c.items)
c.mu.RUnlock()
if itemCount != 1 {
t.Fatalf("items: got %d want 1", itemCount)
}
if listLen != 1 {
t.Fatalf("ORPHAN BUG: lruList.Len()=%d but items=%d (one list element per key expected)", listLen, itemCount)
}
}
// TestUpdateLocalCache_EvictionTerminates is the convoy reproducer: once orphans
// for a key exist and the memory-eviction loop runs, evictOldest() deletes the
// key from items on the first eviction, after which every remaining orphan at
// Back() has a key absent from items -> evictOldest() no-ops while lruList.Len()>0
// stays true -> infinite loop while holding c.mu.Lock(). That is the 100%-CPU
// holder + write-lock convoy observed in pprof.
func TestUpdateLocalCache_EvictionTerminates(t *testing.T) {
c := newOrphanTestCache(0) // start with memory limit OFF to accumulate orphans
const key = "same-key"
// Build 3 same-key list elements (3 orphans, items={key}).
for range 3 {
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
t.Fatalf("seed updateLocalCache: %v", err)
}
}
// Arm the trap: tiny memory limit so the next call enters the eviction loop.
c.mu.Lock()
c.config.MaxMemoryBytes = 1
c.mu.Unlock()
done := make(chan struct{})
go func() {
_ = c.updateLocalCache(key, "v", time.Hour) // triggers the eviction loop
close(done)
}()
select {
case <-done:
// fix present: loop made forward progress and returned
case <-time.After(2 * time.Second):
t.Fatal("INFINITE LOOP: eviction loop did not terminate within 2s (orphan whose key was deleted is never removed from lruList)")
}
}
+1 -83
View File
@@ -19,7 +19,7 @@ import (
// - true if the URL should be excluded from authentication, false otherwise.
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if pathExcluded(currentRequest, excludedURL) {
if strings.HasPrefix(currentRequest, excludedURL) {
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
@@ -27,31 +27,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
return false
}
// pathExcluded reports whether requestPath is covered by an excluded prefix at a
// natural boundary: an exact match, a sub-path ("/public" → "/public/x"), or a
// file extension ("/favicon" → "/favicon.ico"). It deliberately does NOT match
// an unrelated sibling such as "/publicsecret", so a configured exclusion can no
// longer be widened into an authentication bypass on a different resource.
func pathExcluded(requestPath, excluded string) bool {
excluded = strings.TrimRight(excluded, "/")
if excluded == "" {
// A "/" (root) exclusion only matches the root path, not everything.
return requestPath == "" || requestPath == "/"
}
if requestPath == excluded {
return true
}
if !strings.HasPrefix(requestPath, excluded) {
return false
}
switch requestPath[len(excluded)] {
case '/', '.':
return true
default:
return false
}
}
// buildAuthURL constructs the OIDC provider authorization URL.
// It builds the URL with all necessary parameters including client_id, scopes,
// PKCE parameters, and provider-specific parameters for Google and Azure.
@@ -313,63 +288,6 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
return nil
}
// validateDiscoveredEndpoint validates an endpoint URL obtained from the
// provider's OIDC/OAuth2 discovery document before the plugin issues any
// outbound request to it. A discovery document is attacker-influenced if the
// provider is malicious or its TLS is broken, so an unvalidated endpoint is an
// SSRF vector (e.g. jwks_uri or introspection_endpoint pointed at the cloud
// metadata service 169.254.169.254 or an internal host).
//
// Empty endpoints are allowed (they are optional). Link-local (which covers the
// 169.254.0.0/16 metadata range), multicast and unspecified addresses are
// always rejected. Private addresses are rejected unless allowPrivateIPAddresses
// is set. Loopback is rejected unless allowLoopback is true — which the caller
// sets only when the operator-configured providerURL is itself loopback (local
// development, in-cluster sidecars, tests), so production deployments pointed at
// a real provider still block loopback SSRF.
func (t *TraefikOidc) validateDiscoveredEndpoint(urlStr string, allowLoopback bool) error {
if urlStr == "" {
return nil
}
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %w", err)
}
if u.Scheme != "https" && u.Scheme != "http" {
return fmt.Errorf("disallowed URL scheme: %q", u.Scheme)
}
if u.Host == "" {
return fmt.Errorf("missing host in URL")
}
if ip := net.ParseIP(u.Hostname()); ip != nil {
switch {
case ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified():
return fmt.Errorf("endpoint host is a blocked address: %s", ip)
case ip.IsLoopback() && !allowLoopback:
return fmt.Errorf("endpoint host is a loopback address: %s", ip)
case ip.IsPrivate() && !t.allowPrivateIPAddresses:
return fmt.Errorf("endpoint host is a private address: %s", ip)
}
}
if strings.Contains(u.Path, "..") {
return fmt.Errorf("path traversal detected in URL path")
}
return nil
}
// sameHost reports whether two URLs share the same host:port (case-insensitive).
// Used to pin the credential-bearing introspection endpoint to the operator-
// configured provider so a poisoned discovery document cannot redirect the
// client secret to an attacker-controlled host.
func sameHost(a, b string) bool {
ua, erra := url.Parse(a)
ub, errb := url.Parse(b)
if erra != nil || errb != nil || ua.Host == "" || ub.Host == "" {
return false
}
return strings.EqualFold(ua.Host, ub.Host)
}
// validateHost validates a hostname or IP address for security.
// It prevents access to localhost, private networks, and known metadata endpoints.
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
+3 -8
View File
@@ -135,7 +135,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return false
}
domain := strings.ToLower(parts[1])
domain := parts[1]
_, domainAllowed := t.allowedUserDomains[domain]
if domainAllowed {
@@ -236,13 +236,8 @@ func (t *TraefikOidc) Close() error {
// Get resource manager for cleanup
rm := GetResourceManager()
// singleton-token-cleanup is a process-global task shared by every plugin
// instance. Only stop it when the LAST instance is shutting down;
// otherwise one instance's teardown (e.g. a single config reload) would
// kill chunked-session/token cleanup for all surviving instances (rank 12).
if unregisterLiveInstance() <= 0 {
_ = rm.StopBackgroundTask("singleton-token-cleanup") // best effort, last instance only
}
// Stop singleton tasks related to this instance
_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup
// Stop metadata refresh task using same hash-based name as startMetadataRefresh
if t.providerURL != "" {
hash := sha256.Sum256([]byte(t.providerURL))
@@ -1 +0,0 @@
.docs
-36
View File
@@ -1,36 +0,0 @@
version: "2"
run:
timeout: 2m
linters:
default: none
enable:
- bodyclose
- errcheck
- errorlint
- gocritic
- gocyclo
- govet
- ineffassign
- misspell
- prealloc
- revive
- staticcheck
- unconvert
- unused
settings:
gocyclo:
min-complexity: 12
revive:
rules:
- name: var-naming
- name: indent-error-flow
- name: superfluous-else
- name: unused-parameter
- name: redefines-builtin-id
formatters:
enable:
- gofmt
- goimports
-42
View File
@@ -1,42 +0,0 @@
# Configuration for lukaszraczylo/semver-generator.
# Reference: https://github.com/lukaszraczylo/semver-generator
#
# Word matching is fuzzy + case-insensitive. The keywords below mirror the
# Conventional Commits prefixes used in this repo's git history. Same pattern
# as github.com/lukaszraczylo/go-telegram/.semver.yaml.
version: 1
# Respect existing v* tags as the version baseline. semver-generator finds
# the highest existing tag and bumps from there. With no tags yet, the first
# release computes from the empty base.
force:
existing: true
# Skip merge commits and machine-generated traffic that would otherwise
# spuriously bump the version.
blacklist:
- "Merge branch"
- "Merge pull request"
- "Merge remote-tracking branch"
- "go mod tidy"
wording:
patch:
- "fix"
- "chore"
- "docs"
- "test"
- "style"
- "refactor"
- "build"
- "ci"
- "perf"
minor:
- "feat"
major:
# Match only the canonical Conventional Commits trailer. The bare word
# "breaking" is too greedy under semver-generator's fuzzy match — it
# triggers on substrings inside a commit body and wrongly produces a
# major bump.
- "BREAKING CHANGE"
-122
View File
@@ -1,122 +0,0 @@
# oss-telemetry
A tiny Go client that fires one anonymous "this binary started" ping at a
central ingest endpoint. Designed to be embedded in your own open-source
projects so you can see approximate adoption and version spread without
collecting anything that could identify a user.
This is the **client library only**. The ingest endpoint, server-side
deduplication, rate limiting, and metrics are out of scope here.
## What it sends
A single HTTP `POST` per call to `Send`:
```json
{
"project": "my-tool",
"version": "1.2.3",
"ts": 1747782200
}
```
No identifiers, no IP, no machine info, no user data. The server dedupes
incoming requests; the client just fires and forgets.
## Failproof by design
- Never blocks the caller — work runs in a goroutine.
- Never panics — the goroutine recovers internally.
- Never returns errors — bad input and network failures are silently dropped.
- Never retries, never persists state, never reads back.
- 2-second hard timeout on every request.
- Zero third-party dependencies (Go stdlib only).
The endpoint is hardcoded and not overridable from consuming code, by design.
## Install
```bash
go get github.com/lukaszraczylo/oss-telemetry
```
Requires Go 1.22+.
## Usage
```go
package main
import (
"time"
telemetry "github.com/lukaszraczylo/oss-telemetry"
)
const version = "1.2.3"
func main() {
telemetry.Send("my-tool", version)
// ... your program runs ...
// Only needed for short-lived CLIs that may exit before the goroutine
// finishes its POST. Long-running services do not need this.
telemetry.Wait(2 * time.Second)
}
```
Call `Send` once at boot. Calling it more often just sends more pings; the
server deduplicates.
## Disabling telemetry
If you ship a binary that imports this library, link your users to this
section (`https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry`)
so they can find the opt-out paths.
Any one of these turns it off:
| Mechanism | How |
| ---------------------------------------- | ---------------------------------------------------------------- |
| Universal opt-out | `DO_NOT_TRACK=1` |
| Library-wide opt-out | `OSS_TELEMETRY_DISABLED=1` |
| Project-specific opt-out | `<UPPER_PROJECT>_DISABLE_TELEMETRY=1` |
| Programmatic (e.g. behind a `--no-telemetry` flag) | `telemetry.Disable()` before the first `Send` |
Project-specific env var derivation: uppercase the project name and replace
`-` with `_`. For `my-tool` the variable is `MY_TOOL_DISABLE_TELEMETRY`.
Truthy values: `1`, `true`, `yes`, `on` (case-insensitive). Anything else is
treated as "not set".
## Validation rules (silently dropped if violated)
- `project`: matches `^[a-z0-9-]+$`, length 164.
- `version`: matches `^[A-Za-z0-9.+_-]+$`, length 132.
Bad input is a no-op — the library never logs, never errors, never crashes.
## API
```go
// Fire a single ping in the background. Returns immediately.
func Send(project, version string)
// Suppress all subsequent Send calls in this process. Idempotent.
func Disable()
// Block until in-flight pings complete or timeout elapses, whichever first.
// Useful for short-lived CLI processes.
func Wait(timeout time.Duration)
```
## Testing
```bash
go test -race ./...
```
## License
Pick one before publishing. None bundled.
-367
View File
@@ -1,367 +0,0 @@
// Package telemetry sends anonymous usage pings for open-source Go projects.
//
// Wire format (POST application/json):
//
// {"project":"<name>","version":"<ver>","ts":<unix-seconds>}
//
// Design contract (failproof):
// - never blocks the caller (work happens in a goroutine)
// - never panics (background goroutine recovers internally)
// - never returns errors (silently no-ops on bad input or network failure)
// - never retries, never deduplicates, never persists state — the client
// fires a single ping and forgets; the server is responsible for
// deduplication, abuse protection, and aggregation
//
// Typical usage at program startup:
//
// telemetry.Send("my-tool", "1.2.3")
//
// For short-lived CLI processes that may exit before the goroutine finishes:
//
// telemetry.Send("my-tool", "1.2.3")
// defer telemetry.Wait(2 * time.Second)
//
// Disablement (any one of these suppresses pings):
// - environment variable DO_NOT_TRACK=1
// - environment variable OSS_TELEMETRY_DISABLED=1
// - environment variable <UPPER_PROJECT>_DISABLE_TELEMETRY=1
// (project name uppercased, dashes replaced with underscores)
// - calling telemetry.Disable() at runtime
package telemetry
import (
"bytes"
"context"
"net/http"
"os"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
defaultEndpoint = "https://oss.raczylo.com/v1/ping"
httpTimeout = 2 * time.Second
maxProjectLen = 64
maxVersionLen = 32
)
// Yaegi note: this package is consumed by the traefikoidc Traefik plugin, which
// Traefik interprets with Yaegi (it vendors and interprets dependency source).
// It therefore avoids generic stdlib types (atomic.Pointer[T], atomic.Bool) and
// range-over-int (Go 1.22), which some Traefik/Yaegi runtimes cannot interpret.
// Endpoint mutation uses a mutex-guarded string; the disabled flag uses the
// function-based sync/atomic int32 API (atomic.LoadInt32/StoreInt32).
var (
// endpointURL holds the ingest URL. Production code never mutates it; the
// setter exists only so the test suite can retarget it at httptest servers
// while goroutines started by Send are still in flight.
endpointMu sync.RWMutex
endpointURL = defaultEndpoint
disabled int32 // 0 = enabled, 1 = disabled; accessed via sync/atomic only
inflight sync.WaitGroup
client = &http.Client{Timeout: httpTimeout}
)
func currentEndpoint() string {
endpointMu.RLock()
defer endpointMu.RUnlock()
return endpointURL
}
func setEndpointURL(u string) {
endpointMu.Lock()
endpointURL = u
endpointMu.Unlock()
}
// Send fires a single anonymous telemetry ping in the background and returns
// immediately. It never blocks, never panics, and never reports errors.
// Invalid inputs, disabled state, and network failures are silently dropped.
//
// Version strings are validated against a SemVer-ish shape that mirrors the
// receiver. An optional leading "v" or "V" is accepted and stripped before
// transmission so that callers can pass either "v1.2.3" or "1.2.3"; the
// wire form is always the unprefixed canonical version.
//
// Call once at program startup. Calling repeatedly will send repeated pings;
// the server is responsible for deduplication.
func Send(project, version string) {
if atomic.LoadInt32(&disabled) != 0 {
return
}
if isDisabledByEnv(project) {
return
}
if !validProject(project) || !validVersion(version) {
return
}
canonical := normalizeVersion(version)
inflight.Add(1)
go func() {
defer inflight.Done()
defer func() { _ = recover() }()
dispatch(project, canonical)
}()
}
// SendForModule is the recommended call form for Go libraries: it resolves
// the version automatically from Go's build info for the given module path
// so consumers do not need to maintain a hand-bumped version constant in
// source. Behaviour and contract are otherwise identical to [Send].
//
// Resolution order:
//
// 1. debug.ReadBuildInfo Deps entry for modulePath (authoritative when the
// library is consumed via go.mod);
// 2. debug.ReadBuildInfo Main when the library is itself the main module
// (e.g. running its own tests or examples);
// 3. fallback parameter, used only when build info is unavailable or
// unhelpful (replace directives, detached `go run`, ldflag override).
//
// Any leading "v" reported by build info is stripped to match the canonical
// wire form. Empty / "(devel)" build versions are skipped in favour of the
// next resolution source. Typical usage:
//
// telemetry.SendForModule("my-tool", "github.com/me/my-tool", "0.0.0-dev")
func SendForModule(project, modulePath, fallback string) {
Send(project, ResolveModuleVersion(modulePath, fallback))
}
// ResolveModuleVersion implements the version resolution used by
// SendForModule. Exposed for callers that need to format the resolved
// version (e.g. logging) without firing a ping.
func ResolveModuleVersion(modulePath, fallback string) string {
if info, ok := debug.ReadBuildInfo(); ok {
for _, d := range info.Deps {
if d != nil && d.Path == modulePath && isUsableBuildVersion(d.Version) {
return strings.TrimPrefix(d.Version, "v")
}
}
if info.Main.Path == modulePath && isUsableBuildVersion(info.Main.Version) {
return strings.TrimPrefix(info.Main.Version, "v")
}
}
return fallback
}
func isUsableBuildVersion(v string) bool {
return v != "" && v != "(devel)"
}
// Disable suppresses all subsequent Send calls in this process.
// Idempotent and safe to call from any goroutine.
func Disable() {
atomic.StoreInt32(&disabled, 1)
}
// Wait blocks until all in-flight pings have completed, or until timeout
// elapses — whichever comes first. Useful for short-lived CLI processes
// that may otherwise exit before the background goroutine finishes its POST.
//
// A non-positive timeout returns immediately.
func Wait(timeout time.Duration) {
if timeout <= 0 {
return
}
done := make(chan struct{})
go func() {
inflight.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(timeout):
}
}
func dispatch(project, version string) {
body := buildPayload(project, version, time.Now().Unix())
ctx, cancel := context.WithTimeout(context.Background(), httpTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, currentEndpoint(), bytes.NewReader(body))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return
}
_ = resp.Body.Close()
}
// buildPayload writes the JSON body without encoding/json. The validators
// restrict project and version to characters that never require JSON
// escaping, so direct concatenation is safe.
func buildPayload(project, version string, ts int64) []byte {
// Wrapper text plus 20 chars for a signed int64.
const overhead = len(`{"project":"","version":"","ts":}`) + 20
buf := make([]byte, 0, len(project)+len(version)+overhead)
buf = append(buf, `{"project":"`...)
buf = append(buf, project...)
buf = append(buf, `","version":"`...)
buf = append(buf, version...)
buf = append(buf, `","ts":`...)
buf = strconv.AppendInt(buf, ts, 10)
buf = append(buf, '}')
return buf
}
func validProject(p string) bool {
n := len(p)
if n == 0 || n > maxProjectLen {
return false
}
for i := 0; i < n; i++ {
c := p[i]
switch {
case c >= 'a' && c <= 'z',
c >= '0' && c <= '9',
c == '-':
default:
return false
}
}
return true
}
// validVersion accepts SemVer-ish version strings with an optional leading
// "v"/"V" prefix. Acceptable shape (after stripping the leading v):
//
// MAJOR[.MINOR[.PATCH]] ("-"prerelease)? ("+"build)?
//
// where MAJOR/MINOR/PATCH are ASCII digit sequences and the prerelease/build
// payloads are non-empty runs of [0-9A-Za-z.-]. This intentionally mirrors
// the receiver's version regex so junk like "dev" or "git-2026-05-22" never
// leaves the client (where it would only be rejected with HTTP 400 anyway).
func validVersion(v string) bool {
n := len(v)
if n == 0 || n > maxVersionLen {
return false
}
if v[0] == 'v' || v[0] == 'V' {
v = v[1:]
}
if len(v) == 0 {
return false
}
return checkSemverShape(v)
}
// normalizeVersion strips an optional leading "v"/"V" so the on-the-wire
// version matches the form stored server-side by the version refresher cron
// (which also strips the leading v from release tags). Callers may pass
// either "v1.2.3" or "1.2.3" — only the unprefixed form is transmitted.
func normalizeVersion(v string) string {
if len(v) > 0 && (v[0] == 'v' || v[0] == 'V') {
return v[1:]
}
return v
}
func checkSemverShape(s string) bool {
i := 0
if !readDigitRun(s, &i) {
return false
}
for groups := 0; groups < 2 && i < len(s) && s[i] == '.'; groups++ {
i++
if !readDigitRun(s, &i) {
return false
}
}
if i < len(s) && s[i] == '-' {
i++
if !readIdentRun(s, &i, '+') {
return false
}
}
if i < len(s) && s[i] == '+' {
i++
if !readIdentRun(s, &i, 0) {
return false
}
}
return i == len(s)
}
func readDigitRun(s string, i *int) bool {
start := *i
for *i < len(s) && s[*i] >= '0' && s[*i] <= '9' {
*i++
}
return *i > start
}
// readIdentRun consumes [0-9A-Za-z.-] until end-of-string or until `stop`
// is hit (stop=0 disables the early-stop check). Returns false if no
// characters were consumed (i.e. empty payload).
func readIdentRun(s string, i *int, stop byte) bool {
start := *i
for *i < len(s) {
c := s[*i]
if stop != 0 && c == stop {
break
}
valid := (c >= '0' && c <= '9') ||
(c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
c == '.' || c == '-'
if !valid {
return false
}
*i++
}
return *i > start
}
func isDisabledByEnv(project string) bool {
if truthy(os.Getenv("DO_NOT_TRACK")) {
return true
}
if truthy(os.Getenv("OSS_TELEMETRY_DISABLED")) {
return true
}
if project == "" {
return false
}
key := projectEnvKey(project)
return truthy(os.Getenv(key))
}
// projectEnvKey returns "<UPPER_PROJECT>_DISABLE_TELEMETRY" using a single
// allocation rather than chained strings.ToUpper(strings.ReplaceAll(...)).
func projectEnvKey(project string) string {
const suffix = "_DISABLE_TELEMETRY"
buf := make([]byte, 0, len(project)+len(suffix))
for i := 0; i < len(project); i++ {
c := project[i]
switch {
case c == '-':
c = '_'
case c >= 'a' && c <= 'z':
c -= 'a' - 'A'
}
buf = append(buf, c)
}
buf = append(buf, suffix...)
return string(buf)
}
func truthy(s string) bool {
switch strings.ToLower(strings.TrimSpace(s)) {
case "1", "true", "yes", "on":
return true
}
return false
}
-3
View File
@@ -24,9 +24,6 @@ github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# github.com/lukaszraczylo/oss-telemetry v0.2.3
## explicit; go 1.22
github.com/lukaszraczylo/oss-telemetry
# github.com/pmezard/go-difflib v1.0.0
## explicit
github.com/pmezard/go-difflib/difflib
-17
View File
@@ -1,17 +0,0 @@
package traefikoidc
// devPluginVersion is the placeholder carried by source-tree / local / test
// builds. Telemetry is suppressed while the plugin still reports this sentinel,
// so only stamped release builds emit a "plugin loaded" ping.
const devPluginVersion = "0.0.0-dev"
// traefikoidcPluginVersion is the released version of this plugin. It is stamped
// at release time by ./workflow-prepare.sh (invoked by the shared go-release
// workflow before GoReleaser builds and tags), which rewrites the string below
// to the computed semver.
//
// Traefik runs this plugin under Yaegi, where the version cannot be resolved
// from build info at runtime (debug.ReadBuildInfo sees Traefik's build graph,
// not the interpreted plugin). This build-stamped constant is therefore the
// single source of truth for the version reported by anonymous usage telemetry.
const traefikoidcPluginVersion = "0.0.0-dev"
-67
View File
@@ -1,67 +0,0 @@
#!/usr/bin/env bash
#
# workflow-prepare.sh — stamp the release version into version.go at build time.
#
# The shared go-release workflow (lukaszraczylo/shared-actions go-release.yaml)
# runs this script, if present, from the repository root BEFORE GoReleaser
# builds and tags. Traefik runs this plugin under Yaegi, where the version
# cannot be resolved from build info at runtime, so the released semver must be
# baked into source here.
#
# Version source — first non-empty wins:
# $VERSION $VERSION_TAG $SEMVER $NEW_VERSION $RELEASE_VERSION
# A leading "v"/"V" is stripped.
#
# NOTE: go-release.yaml @main does not yet pass the computed version into this
# step's environment. Add it to the "Run workflow prepare script" step, e.g.:
# env:
# VERSION: ${{ needs.version.outputs.version }} # bare, no leading v
#
# The shared workflow runs this script in its test, version AND release jobs,
# but only the release job has a computed version. So a missing version is a
# no-op (leave the dev sentinel) — NOT a hard failure, otherwise the test/version
# jobs would break. A malformed version that IS provided is a hard error. Wire
# the env only on the release job's prepare step (see header note above).
set -euo pipefail
FILE="version.go"
CONST="traefikoidcPluginVersion"
VER="${VERSION:-${VERSION_TAG:-${SEMVER:-${NEW_VERSION:-${RELEASE_VERSION:-}}}}}"
VER="${VER#v}"
VER="${VER#V}"
if [ -z "$VER" ]; then
if [ "${GITHUB_ACTIONS:-}" = "true" ]; then
echo "workflow-prepare: WARNING no version provided; leaving ${FILE} at the dev placeholder. If this is the release build, set 'env: VERSION: \${{ needs.version.outputs.version }}' on the release job's prepare step — otherwise the release ships 0.0.0-dev and emits no telemetry." >&2
else
echo "workflow-prepare: no version provided; leaving dev placeholder in ${FILE} (local build)"
fi
exit 0
fi
# Accept MAJOR[.MINOR[.PATCH]] with optional -prerelease / +build (semver-ish,
# matching the oss-telemetry receiver's validator).
if ! printf '%s' "$VER" | grep -Eq '^[0-9]+(\.[0-9]+){0,2}(-[0-9A-Za-z.-]+)?(\+[0-9A-Za-z.-]+)?$'; then
echo "workflow-prepare: ERROR version '${VER}' is not semver-shaped" >&2
exit 1
fi
if [ ! -f "$FILE" ]; then
echo "workflow-prepare: ERROR ${FILE} not found (run from repository root)" >&2
exit 1
fi
# Rewrite only the value of ${CONST}, anchored on the constant name so the
# sibling devPluginVersion sentinel is left untouched.
tmp="$(mktemp)"
sed -E "s/(${CONST}[[:space:]]*=[[:space:]]*\")[^\"]*(\")/\1${VER}\2/" "$FILE" > "$tmp"
mv "$tmp" "$FILE"
if ! grep -Eq "${CONST}[[:space:]]*=[[:space:]]*\"${VER}\"" "$FILE"; then
echo "workflow-prepare: ERROR failed to stamp version into ${FILE}" >&2
exit 1
fi
command -v gofmt >/dev/null 2>&1 && gofmt -w "$FILE"
echo "workflow-prepare: stamped ${CONST} = \"${VER}\" in ${FILE}"