diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 37a9175..1d58739 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -18,6 +18,6 @@ jobs: pr-checks: uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main with: - go-version: "1.24.11" + go-version: "1.25.x" coverage-threshold: 70 secrets: inherit diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cb1028a..0be7935 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,5 +19,5 @@ jobs: release: uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main with: - go-version: "1.24.11" + go-version: "1.25.x" secrets: inherit diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..33e9735 --- /dev/null +++ b/Makefile @@ -0,0 +1,61 @@ +# traefikoidc — Makefile +# Run `make help` for available targets. + +GO ?= go +GOPATH := $(shell $(GO) env GOPATH) +# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate +# tests the real interpreter, not a newer one that may support more. Traefik +# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is. +YAEGI_VERSION ?= v0.16.1 +TEST_TIMEOUT ?= 480s + +.DEFAULT_GOAL := help + +.PHONY: help +help: ## Show this help + @grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}' + +.PHONY: build +build: ## Compile all packages (native toolchain) + $(GO) build ./... + +.PHONY: fmt +fmt: ## Format sources with gofmt + gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/') + +.PHONY: vet +vet: ## Run go vet + $(GO) vet ./... + +.PHONY: lint +lint: ## Run golangci-lint if available + @command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping" + +.PHONY: staticcheck +staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.) + @command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; } + @GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./... + +.PHONY: test +test: ## Run the test suite + $(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT) + +.PHONY: vendor +vendor: ## Refresh and vendor dependencies + $(GO) mod tidy && $(GO) mod vendor + +# yaegi-validate interprets the plugin under the yaegi interpreter the same way +# Traefik loads it. Native `go build`/`go test` use the standard compiler and do +# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection +# edge cases). This target does. Importing the package forces yaegi to interpret +# every file in it plus its vendored deps; CreateConfig + New exercise the +# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you +# need exact parity. +.PHONY: yaegi-validate +yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter + @command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); } + @echo ">> interpreting plugin under yaegi (as Traefik does)" + @DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go + +.PHONY: check +check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation diff --git a/README.md b/README.md index 61182b3..3cdf8ab 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md). | `postLogoutRedirectURI` | `/` | Where to send users after logout. | | `scopes` | appended to `openid profile email` | Extra OAuth scopes. Set `overrideScopes: true` to replace defaults. | | `extraAuthParams` | none | Map of extra query parameters appended to the authorization request (e.g. `screen_hint: signup`, `login_hint`, `ui_locales`, `prompt`). Plugin-managed params (`client_id`, `state`, `nonce`, `redirect_uri`, `code_challenge`, `scope`, `response_type`, …) cannot be overridden. | -| `excludedURLs` | none | Prefix-matched paths that bypass auth. | +| `excludedURLs` | none | Paths that bypass auth, matched at a path-segment or file-extension boundary (e.g. `/public` matches `/public`, `/public/sub` and `/public.json`, but **not** `/publicsecret`). | | `allowedUserDomains` | none | Restrict to email domains. | | `allowedUsers` | none | Restrict to specific addresses (or claim values when `userIdentifierClaim != email`). | | `allowedRolesAndGroups` | none | Require any of these roles/groups from ID-token claims. | @@ -147,6 +147,18 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md). ## Production gotchas +### Upgrading from an earlier release + +- **Sessions are re-issued once.** Session cookies are now AES-256 encrypted + (previously signed only) and their cryptographic lifetime tracks + `sessionMaxAge` (previously a fixed 30 days). Existing cookies become invalid + on upgrade, so users re-authenticate one time. +- **Invalid configuration now fails closed at startup** instead of being + silently accepted: a `sessionEncryptionKey` shorter than 32 bytes, a + `rateLimit` below 10, a missing `callbackURL`, or a non-HTTPS remote + `providerURL` are rejected. Plaintext HTTP is permitted only for loopback + hosts (local development). + ### TLS termination at a load balancer `forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is @@ -166,6 +178,8 @@ detected" when the same token hits different replicas. Two options: For IdP-initiated logout (back/front-channel) in multi-replica setups, Redis is **required** so a logout on one instance invalidates sessions on the others. +Front-channel logout requests must include a matching `iss` query parameter; +requests that omit it are rejected with `400`. ### Multiple middleware instances on the same host diff --git a/auth_flow.go b/auth_flow.go index df96bd0..56f9d0c 100644 --- a/auth_flow.go +++ b/auth_flow.go @@ -182,6 +182,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, } codeVerifier := session.GetCodeVerifier() + if t.enablePKCE && codeVerifier == "" { + t.logger.Error("PKCE is enabled but code verifier is missing from session during callback") + t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest) + return + } tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier) if err != nil { @@ -263,7 +268,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectPath := "/" if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { - redirectPath = incomingPath + // Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored + // from the original request target before using it as the post-login + // redirect target. normalizeLogoutPath forces a host-relative path. + redirectPath = normalizeLogoutPath(incomingPath) } session.SetIncomingPath("") diff --git a/bearer_auth.go b/bearer_auth.go index 8796ca3..20bc0da 100644 --- a/bearer_auth.go +++ b/bearer_auth.go @@ -149,6 +149,94 @@ func parseBearerJOSEHeader(token string) *bearerError { return nil } +// headerClaimRuneReason reports why a rune is unsafe to inject into a request +// header value, or "" if the rune is acceptable. Shared core of the bearer-path +// identifier sanitizer and the cookie-path header claim sanitizer: rejects +// control chars (CRLF/header injection), Unicode bidi-override runes (RTL +// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group +// name would inject extra entries into a comma-joined header). +func headerClaimRuneReason(r rune) string { + if reason := headerInjectionRuneReason(r); reason != "" { + return reason + } + // The , ; = delimiters are only unsafe for values placed into delimited or + // list contexts (a comma-joined header, or an identifier downstreams may + // split). They are valid in arbitrary single header values, so this stricter + // check is used for the cookie-path identifier and the group/role list, NOT + // for free-form templated header output (see headerValueReason). + if r == ',' || r == ';' || r == '=' { + return "delimiter character" + } + return "" +} + +// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header +// value, or "" if acceptable. Rejects control characters (CR/LF header +// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs). +// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in +// free-form header values (e.g. an opaque "Authorization: Bearer "). +func headerInjectionRuneReason(r rune) string { + if unicode.IsControl(r) { + return "control character" + } + if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) { + return "bidi-override character" + } + return "" +} + +// headerValueReason reports why value is unsafe to forward as a free-form HTTP +// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0 +// disables the check) and values containing control or bidi-override runes, but +// permits , ; = (valid in header values). Empty is allowed. The reason string +// never includes the value, so it is safe to log. +func headerValueReason(value string, maxLen int) string { + if maxLen > 0 && len(value) > maxLen { + return "exceeds max length" + } + for _, r := range value { + if reason := headerInjectionRuneReason(r); reason != "" { + return reason + } + } + return "" +} + +// headerClaimValueReason reports why value is unsafe to inject into a +// downstream request header, or "" if it is acceptable. It rejects empty +// values, values exceeding maxLen (maxLen<=0 disables the length check), and +// values containing any rune rejected by headerClaimRuneReason. The reason +// string is safe to log (it never includes the value itself). +func headerClaimValueReason(value string, maxLen int) string { + if value == "" { + return "empty value" + } + if maxLen > 0 && len(value) > maxLen { + return "exceeds max length" + } + for _, r := range value { + if reason := headerClaimRuneReason(r); reason != "" { + return reason + } + } + return "" +} + +// sanitizeHeaderClaimValue validates a claim-derived value before it is +// injected into a downstream request header. It trims surrounding whitespace +// and fails closed (ok=false) on empty values, values exceeding maxLen +// (maxLen<=0 disables the length check), or values containing any rune rejected +// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the +// bearer path — does not otherwise sanitize the principal identifier or the +// group/role strings joined into X-User-Groups / X-User-Roles. +func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) { + value := strings.TrimSpace(raw) + if headerClaimValueReason(value, maxLen) != "" { + return "", false + } + return value, true +} + // sanitizeBearerIdentifier validates and trims a principal identifier before // it is injected into request headers. Layered defense: net/http will reject // CRLF on the wire too, but rejecting early gives clearer error logs and @@ -163,15 +251,8 @@ func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) { return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length") } for _, r := range identifier { - if unicode.IsControl(r) { - return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains control character") - } - // Unicode bidi-override range (RTL spoofing of admin UI / SIEM). - if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) { - return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains bidi-override character") - } - if r == ',' || r == ';' || r == '=' { - return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains delimiter character") + if reason := headerClaimRuneReason(r); reason != "" { + return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason) } } return identifier, nil @@ -342,7 +423,17 @@ func (b *bearerFailureTracker) recordSuccess(ip string) { } b.mu.Lock() defer b.mu.Unlock() - delete(b.entries, ip) + e, ok := b.entries[ip] + if !ok { + return + } + // Preserve an active penalty so a single success cannot wipe an in-effect + // lockout; only reset the counter when no penalty is active or it has expired. + now := time.Now() + if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) { + e.count = 0 + e.firstFailureAt = now + } } // clientIPForBearer returns the source IP used to key the failure tracker. diff --git a/bearer_auth_test.go b/bearer_auth_test.go index 118ef5f..3f253e8 100644 --- a/bearer_auth_test.go +++ b/bearer_auth_test.go @@ -67,31 +67,31 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc { t.Helper() sm := createTestSessionManager(t) oidc := &TraefikOidc{ - next: next, - logger: NewLogger("error"), - initComplete: make(chan struct{}), - sessionManager: sm, - firstRequestStarted: 1, + next: next, + logger: NewLogger("error"), + initComplete: make(chan struct{}), + sessionManager: sm, + firstRequestStarted: 1, metadataRefreshStartedAtomic: 1, - issuerURL: "https://issuer.example.com", - audience: "https://api.example.com", - clientID: "https://api.example.com", - tokenCache: NewTokenCache(), - excludedURLs: map[string]struct{}{"/favicon.ico": {}}, - allowedRolesAndGroups: map[string]struct{}{}, - limiter: rate.NewLimiter(rate.Every(time.Second), 1000), - ctx: context.Background(), - enableBearerAuth: true, - stripAuthorizationHeader: true, - bearerEmitWWWAuthenticate: true, - bearerOverridesCookie: false, - bearerIdentifierClaim: "sub", - maxIdentifierLength: 256, - maxTokenAge: 24 * time.Hour, - bearerFailureThreshold: 20, - bearerFailureWindow: 60 * time.Second, - bearerFailurePenalty: 60 * time.Second, - bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second), + issuerURL: "https://issuer.example.com", + audience: "https://api.example.com", + clientID: "https://api.example.com", + tokenCache: NewTokenCache(), + excludedURLs: map[string]struct{}{"/favicon.ico": {}}, + allowedRolesAndGroups: map[string]struct{}{}, + limiter: rate.NewLimiter(rate.Every(time.Second), 1000), + ctx: context.Background(), + enableBearerAuth: true, + stripAuthorizationHeader: true, + bearerEmitWWWAuthenticate: true, + bearerOverridesCookie: false, + bearerIdentifierClaim: "sub", + maxIdentifierLength: 256, + maxTokenAge: 24 * time.Hour, + bearerFailureThreshold: 20, + bearerFailureWindow: 60 * time.Second, + bearerFailurePenalty: 60 * time.Second, + bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second), } oidc.extractClaimsFunc = extractClaims close(oidc.initComplete) @@ -303,15 +303,33 @@ func TestBearerFailureTracker(t *testing.T) { if b, retry := tr.blocked(ip); !b || retry <= 0 { t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry) } - // Success clears the counter. + // A success while a penalty is active must NOT wipe the in-effect lockout + // (otherwise a single success could clear an attacker's penalty). tr.recordSuccess(ip) - if b, _ := tr.blocked(ip); b { - t.Fatalf("expected unblocked after success") + if b, _ := tr.blocked(ip); !b { + t.Fatalf("expected still blocked after success while penalty active") } // Other IPs are unaffected. if b, _ := tr.blocked("10.0.0.2"); b { t.Fatalf("unrelated IP should not be blocked") } + + // With an expired penalty, a success resets the counter so a subsequent + // sub-threshold failure does not immediately re-block. + tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond) + const ip2 = "10.0.0.3" + for i := 0; i < 3; i++ { + tr2.recordFailure(ip2) + } + time.Sleep(5 * time.Millisecond) // let the short penalty expire + if b, _ := tr2.blocked(ip2); b { + t.Fatalf("expected unblocked after penalty expiry") + } + tr2.recordSuccess(ip2) // resets count since penalty has passed + tr2.recordFailure(ip2) // single failure, well below threshold + if b, _ := tr2.blocked(ip2); b { + t.Fatalf("expected unblocked: counter should have reset after success") + } } // ============================================================================= diff --git a/cache_manager.go b/cache_manager.go index 0508cbd..fe9cc0c 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -16,8 +16,9 @@ type CacheManager struct { } var ( - globalCacheManagerInstance *CacheManager - cacheManagerInitOnce sync.Once + globalCacheManagerInstance *CacheManager + cacheManagerInitOnce sync.Once + cacheManagerActiveFingerprint string ) // GetGlobalCacheManager returns a singleton CacheManager instance. @@ -29,7 +30,9 @@ func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager { // GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager { + fp := redisFingerprint(config) cacheManagerInitOnce.Do(func() { + cacheManagerActiveFingerprint = fp var redisConfig *RedisConfig var logger *Logger @@ -55,9 +58,27 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM manager: GetUniversalCacheManagerWithConfig(logger, redisConfig), } }) + // Warn loudly if a later instance asks for a DIFFERENT explicit Redis + // backend than the one that won initialization: the cache manager is a + // process-global singleton shared across plugin instances (yaegi), so this + // instance's divergent configuration is silently ignored, which would + // otherwise collapse cache/state isolation between routes (rank 9). + if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint { + NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp) + } return globalCacheManagerInstance } +// redisFingerprint returns a stable identifier for an explicitly-enabled Redis +// backend (address + key prefix), or "" when Redis is not explicitly enabled. +// Used to detect divergent cache configurations across plugin instances. +func redisFingerprint(config *Config) string { + if config == nil || config.Redis == nil || !config.Redis.Enabled { + return "" + } + return config.Redis.Address + "|" + config.Redis.KeyPrefix +} + // GetSharedTokenBlacklist returns the shared token blacklist cache func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface { cm.mu.RLock() diff --git a/cmd/yaegicheck/main.go b/cmd/yaegicheck/main.go new file mode 100644 index 0000000..78b0670 --- /dev/null +++ b/cmd/yaegicheck/main.go @@ -0,0 +1,46 @@ +//go:build ignore + +// Command yaegicheck verifies that the traefikoidc plugin can be imported and +// instantiated by the yaegi interpreter — the same way Traefik loads a plugin. +// +// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi +// to interpret every source file in the package (and its vendored +// dependencies), so any construct yaegi cannot handle (unsupported stdlib +// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load +// time. CreateConfig + New additionally exercise the instantiation path +// (session manager, cookie codec, caches, key derivation) under the interpreter. +package main + +import ( + "context" + "fmt" + "net/http" + "os" + + oidc "github.com/lukaszraczylo/traefikoidc" +) + +func main() { + cfg := oidc.CreateConfig() + cfg.ProviderURL = "https://accounts.google.com" + cfg.ClientID = "yaegi-check-client" + cfg.ClientSecret = "yaegi-check-secret" + cfg.CallbackURL = "/oauth2/callback" + cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef" + cfg.RateLimit = 100 + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h, err := oidc.New(context.Background(), next, cfg, "yaegi-check") + if err != nil { + fmt.Println("FAIL: New returned an error under yaegi:", err) + os.Exit(1) + } + if h == nil { + fmt.Println("FAIL: New returned a nil handler under yaegi") + os.Exit(1) + } + if closer, ok := h.(interface{ Close() error }); ok { + _ = closer.Close() + } + fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi") +} diff --git a/coverage_boost_final_test.go b/coverage_boost_final_test.go index b8d291e..952a6b3 100644 --- a/coverage_boost_final_test.go +++ b/coverage_boost_final_test.go @@ -278,82 +278,6 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) { } } -// ============================================================================= -// SECURITY MONITORING TESTS -// ============================================================================= - -func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) { - logger := NewLogger("info") - config := SecurityMonitorConfig{ - MaxFailuresPerIP: 5, - FailureWindowMinutes: 15, - BlockDurationMinutes: 30, - RapidFailureThreshold: 3, - CleanupIntervalMinutes: 60, - RetentionHours: 24, - EnablePatternDetection: true, - EnableDetailedLogging: false, - LogSuspiciousOnly: false, - } - - sm := NewSecurityMonitor(config, logger) - if sm == nil { - t.Fatal("Expected non-nil SecurityMonitor") - } - - // Start cleanup routine first (lowercase method) - sm.startCleanupRoutine() - - // Give it a moment to start - time.Sleep(50 * time.Millisecond) - - // Stop cleanup routine (public method) - sm.StopCleanupRoutine() - - // Stop again should be safe - sm.StopCleanupRoutine() -} - -func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) { - logger := NewLogger("info") - config := SecurityMonitorConfig{ - MaxFailuresPerIP: 5, - FailureWindowMinutes: 15, - BlockDurationMinutes: 30, - RapidFailureThreshold: 3, - CleanupIntervalMinutes: 60, - RetentionHours: 24, - } - - sm := NewSecurityMonitor(config, logger) - - // Create handler - handler := &LoggingSecurityEventHandler{logger: logger} - - // Register handler using AddEventHandler - sm.AddEventHandler(handler) - - // Record a failure to trigger events - sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil) -} - -func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) { - logger := NewLogger("debug") - handler := &LoggingSecurityEventHandler{logger: logger} - - // Severity is a string in this implementation - events := []SecurityEvent{ - {Type: "test", Severity: "low", Message: "low severity"}, - {Type: "test", Severity: "medium", Message: "medium severity"}, - {Type: "test", Severity: "high", Message: "high severity"}, - {Type: "test", Severity: "critical", Message: "critical severity"}, - } - - for _, event := range events { - handler.HandleSecurityEvent(event) - } -} - // ============================================================================= // SESSION MANAGER TESTS // ============================================================================= diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 408cce4..bac456d 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -178,7 +178,7 @@ clientSecret: your-client-secret | `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) | | `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) | | `rateLimit` | int | `100` | Maximum requests per second | -| `excludedURLs` | []string | none | Paths that bypass authentication | +| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary | | `revocationURL` | string | auto-discovered | Token revocation endpoint | | `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint | | `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow | diff --git a/dynamic_client_registration.go b/dynamic_client_registration.go index 83b50bb..7da699d 100644 --- a/dynamic_client_registration.go +++ b/dynamic_client_registration.go @@ -370,21 +370,6 @@ func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, res return r.saveCredentials(resp) } -// deleteCredentialsFromStore removes credentials from the configured storage backend -// Falls back to legacy file-based deletion if no store is configured -func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error { - // Use store if available - if r.store != nil { - return r.store.Delete(ctx, r.providerURL) - } - // Fallback to legacy file-based deletion - filePath := r.credentialsFilePath() - if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { - return err - } - return nil -} - // saveCredentials persists client credentials to a file (legacy method) func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error { filePath := r.credentialsFilePath() @@ -423,187 +408,3 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, return &resp, nil } - -// UpdateClientRegistration updates an existing client registration using RFC 7592 -// This requires the registration_client_uri and registration_access_token from the original registration -func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) { - r.mu.RLock() - cachedResp := r.registrationResponse - r.mu.RUnlock() - - if cachedResp == nil { - return nil, fmt.Errorf("no existing registration to update") - } - - if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" { - return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token") - } - - // Build update request - reqBody, err := r.buildRegistrationRequest() - if err != nil { - return nil, fmt.Errorf("failed to build update request: %w", err) - } - - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody)) - if err != nil { - return nil, fmt.Errorf("failed to create update request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken) - - // Execute request - resp, err := r.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("update request failed: %w", err) - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("failed to read update response: %w", err) - } - - // Handle error responses - if resp.StatusCode != http.StatusOK { - var regError ClientRegistrationError - if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { - return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription) - } - return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Parse successful response - var regResp ClientRegistrationResponse - if err := json.Unmarshal(body, ®Resp); err != nil { - return nil, fmt.Errorf("failed to parse update response: %w", err) - } - - // Update cache - r.mu.Lock() - r.registrationResponse = ®Resp - r.mu.Unlock() - - // Persist updated credentials if enabled - if r.config.PersistCredentials { - if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil { - r.logger.Errorf("Failed to persist updated credentials: %v", err) - } - } - - r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID) - return ®Resp, nil -} - -// ReadClientRegistration reads the current client registration using RFC 7592 -func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) { - r.mu.RLock() - cachedResp := r.registrationResponse - r.mu.RUnlock() - - if cachedResp == nil { - return nil, fmt.Errorf("no existing registration to read") - } - - if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" { - return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token") - } - - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil) - if err != nil { - return nil, fmt.Errorf("failed to create read request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken) - - // Execute request - resp, err := r.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("read request failed: %w", err) - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - // Handle error responses - if resp.StatusCode != http.StatusOK { - var regError ClientRegistrationError - if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { - return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription) - } - return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Parse successful response - var regResp ClientRegistrationResponse - if err := json.Unmarshal(body, ®Resp); err != nil { - return nil, fmt.Errorf("failed to parse read response: %w", err) - } - - return ®Resp, nil -} - -// DeleteClientRegistration deletes the client registration using RFC 7592 -func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error { - r.mu.RLock() - cachedResp := r.registrationResponse - r.mu.RUnlock() - - if cachedResp == nil { - return fmt.Errorf("no existing registration to delete") - } - - if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" { - return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token") - } - - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil) - if err != nil { - return fmt.Errorf("failed to create delete request: %w", err) - } - - req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken) - - // Execute request - resp, err := r.httpClient.Do(req) - if err != nil { - return fmt.Errorf("delete request failed: %w", err) - } - defer resp.Body.Close() - - // Handle error responses (204 No Content is success) - if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - var regError ClientRegistrationError - if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { - return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription) - } - return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Clear cache - r.mu.Lock() - r.registrationResponse = nil - r.mu.Unlock() - - // Remove credentials from storage if persistence is enabled - if r.config.PersistCredentials { - if err := r.deleteCredentialsFromStore(ctx); err != nil { - r.logger.Errorf("Failed to remove credentials from storage: %v", err) - } - } - - r.logger.Info("Successfully deleted client registration") - return nil -} diff --git a/dynamic_client_registration_test.go b/dynamic_client_registration_test.go index 1efcf85..29c4b9a 100644 --- a/dynamic_client_registration_test.go +++ b/dynamic_client_registration_test.go @@ -735,258 +735,6 @@ func TestDCRConfigDefaults(t *testing.T) { } } -// TestUpdateClientRegistration tests the RFC 7592 client update functionality -func TestUpdateClientRegistration(t *testing.T) { - updateCalled := false - - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPut { - updateCalled = true - - // Verify authorization header - if r.Header.Get("Authorization") == "" { - t.Error("Missing Authorization header for update") - } - - resp := ClientRegistrationResponse{ - ClientID: "updated-client-id", - ClientSecret: "updated-client-secret", - RegistrationAccessToken: "new-access-token", - RegistrationClientURI: r.URL.String(), - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(resp) - } - })) - defer server.Close() - - dcrConfig := &DynamicClientRegistrationConfig{ - Enabled: true, - ClientMetadata: &ClientRegistrationMetadata{ - RedirectURIs: []string{"https://example.com/callback"}, - }, - } - - registrar := NewDynamicClientRegistrar( - server.Client(), - NewLogger("DEBUG"), - dcrConfig, - server.URL, - ) - - // Set up cached response with management credentials - registrar.mu.Lock() - registrar.registrationResponse = &ClientRegistrationResponse{ - ClientID: "original-client-id", - ClientSecret: "original-client-secret", - RegistrationAccessToken: "access-token", - RegistrationClientURI: server.URL + "/register/client123", - } - registrar.mu.Unlock() - - // Perform update - ctx := context.Background() - resp, err := registrar.UpdateClientRegistration(ctx) - - if err != nil { - t.Fatalf("Update failed: %v", err) - } - - if !updateCalled { - t.Error("Update endpoint was not called") - } - - if resp.ClientID != "updated-client-id" { - t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID) - } -} - -// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality -func TestDeleteClientRegistration(t *testing.T) { - deleteCalled := false - - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodDelete { - deleteCalled = true - w.WriteHeader(http.StatusNoContent) - } - })) - defer server.Close() - - tempDir := t.TempDir() - credentialsFile := filepath.Join(tempDir, "credentials.json") - - // Create a credentials file to test deletion - os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600) - - dcrConfig := &DynamicClientRegistrationConfig{ - Enabled: true, - PersistCredentials: true, - CredentialsFile: credentialsFile, - } - - registrar := NewDynamicClientRegistrar( - server.Client(), - NewLogger("DEBUG"), - dcrConfig, - server.URL, - ) - - // Set up cached response with management credentials - registrar.mu.Lock() - registrar.registrationResponse = &ClientRegistrationResponse{ - ClientID: "test-client-id", - RegistrationAccessToken: "access-token", - RegistrationClientURI: server.URL + "/register/client123", - } - registrar.mu.Unlock() - - // Perform delete - ctx := context.Background() - err := registrar.DeleteClientRegistration(ctx) - - if err != nil { - t.Fatalf("Delete failed: %v", err) - } - - if !deleteCalled { - t.Error("Delete endpoint was not called") - } - - // Verify cache is cleared - if registrar.GetCachedResponse() != nil { - t.Error("Cached response should be cleared after deletion") - } - - // Verify credentials file is deleted - if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) { - t.Error("Credentials file should be deleted") - } -} - -// TestReadClientRegistration tests the RFC 7592 client read functionality -func TestReadClientRegistration(t *testing.T) { - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet { - resp := ClientRegistrationResponse{ - ClientID: "read-client-id", - ClientSecret: "read-client-secret", - RedirectURIs: []string{"https://example.com/callback"}, - ResponseTypes: []string{"code"}, - GrantTypes: []string{"authorization_code"}, - ApplicationType: "web", - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(resp) - } - })) - defer server.Close() - - dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} - - registrar := NewDynamicClientRegistrar( - server.Client(), - NewLogger("DEBUG"), - dcrConfig, - server.URL, - ) - - // Set up cached response with management credentials - registrar.mu.Lock() - registrar.registrationResponse = &ClientRegistrationResponse{ - ClientID: "original-client-id", - RegistrationAccessToken: "access-token", - RegistrationClientURI: server.URL + "/register/client123", - } - registrar.mu.Unlock() - - // Read registration - ctx := context.Background() - resp, err := registrar.ReadClientRegistration(ctx) - - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - if resp.ClientID != "read-client-id" { - t.Errorf("Read ClientID mismatch: got %s", resp.ClientID) - } -} - -// TestOperationsWithoutCachedResponse tests error handling when no cached response exists -func TestOperationsWithoutCachedResponse(t *testing.T) { - dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} - - registrar := NewDynamicClientRegistrar( - &http.Client{}, - NewLogger("DEBUG"), - dcrConfig, - "https://example.com", - ) - - ctx := context.Background() - - // Test Update without cached response - _, err := registrar.UpdateClientRegistration(ctx) - if err == nil || !stringContains(err.Error(), "no existing registration") { - t.Errorf("Update should fail without cached response: %v", err) - } - - // Test Read without cached response - _, err = registrar.ReadClientRegistration(ctx) - if err == nil || !stringContains(err.Error(), "no existing registration") { - t.Errorf("Read should fail without cached response: %v", err) - } - - // Test Delete without cached response - err = registrar.DeleteClientRegistration(ctx) - if err == nil || !stringContains(err.Error(), "no existing registration") { - t.Errorf("Delete should fail without cached response: %v", err) - } -} - -// TestOperationsWithoutManagementCredentials tests error handling without management URIs -func TestOperationsWithoutManagementCredentials(t *testing.T) { - dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} - - registrar := NewDynamicClientRegistrar( - &http.Client{}, - NewLogger("DEBUG"), - dcrConfig, - "https://example.com", - ) - - // Set up cached response WITHOUT management credentials - registrar.mu.Lock() - registrar.registrationResponse = &ClientRegistrationResponse{ - ClientID: "test-client-id", - // Missing RegistrationAccessToken and RegistrationClientURI - } - registrar.mu.Unlock() - - ctx := context.Background() - - // Test Update without management credentials - _, err := registrar.UpdateClientRegistration(ctx) - if err == nil || !stringContains(err.Error(), "registration management not supported") { - t.Errorf("Update should fail without management credentials: %v", err) - } - - // Test Read without management credentials - _, err = registrar.ReadClientRegistration(ctx) - if err == nil || !stringContains(err.Error(), "registration management not supported") { - t.Errorf("Read should fail without management credentials: %v", err) - } - - // Test Delete without management credentials - err = registrar.DeleteClientRegistration(ctx) - if err == nil || !stringContains(err.Error(), "registration management not supported") { - t.Errorf("Delete should fail without management credentials: %v", err) - } -} - // stringContains is a helper function to check if a string contains a substring func stringContains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr)) diff --git a/error_recovery.go b/error_recovery.go index 0ffef7d..8b2231a 100644 --- a/error_recovery.go +++ b/error_recovery.go @@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool { return true } - errStr := err.Error() + errStr := strings.ToLower(err.Error()) for _, retryableErr := range re.config.RetryableErrors { - if contains(errStr, retryableErr) { + if contains(errStr, strings.ToLower(retryableErr)) { return true } } @@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool { if netErr.Timeout() { return true } - errStr := netErr.Error() + errStr := strings.ToLower(netErr.Error()) temporaryPatterns := []string{ "connection refused", "connection reset", @@ -859,8 +859,9 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f // isServiceDegraded checks if a service is currently degraded func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool { - gd.mutex.RLock() - defer gd.mutex.RUnlock() + // Uses a write lock because the recovery-timeout branch deletes from the map. + gd.mutex.Lock() + defer gd.mutex.Unlock() degradedTime, exists := gd.degradedServices[serviceName] if !exists { diff --git a/helpers.go b/helpers.go index be7d93e..f6c17b1 100644 --- a/helpers.go +++ b/helpers.go @@ -392,10 +392,19 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { baseURL := fmt.Sprintf("%s://%s", scheme, host) postLogoutRedirectURI := t.postLogoutRedirectURI + // localRedirect is used when there is no provider end-session endpoint and + // the plugin redirects the browser itself. It must never be an absolute URL + // derived from the request host (X-Forwarded-Host is client-controllable and + // would be an open redirect); use a host-relative path, or the operator's + // own configured absolute URL, instead. + localRedirect := "/" if postLogoutRedirectURI == "" { postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL) } else if !strings.HasPrefix(postLogoutRedirectURI, "http") { + localRedirect = normalizeLogoutPath(postLogoutRedirectURI) postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI) + } else { + localRedirect = postLogoutRedirectURI } // Read endSessionURL with RLock @@ -414,7 +423,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { return } - http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound) + http.Redirect(rw, req, localRedirect, http.StatusFound) } // BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint. diff --git a/http_client_pool.go b/http_client_pool.go index 020981d..2d44d87 100644 --- a/http_client_pool.go +++ b/http_client_pool.go @@ -26,6 +26,10 @@ type sharedTransport struct { lastUsed time.Time transport *http.Transport refCount int + // tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify) + // this transport was built with, so the at-limit fallback only reuses a + // transport whose TLS configuration matches the caller's. + tlsKey string } var ( @@ -53,19 +57,26 @@ func GetGlobalTransportPool() *SharedTransportPool { // GetOrCreateTransport gets or creates a shared transport with the given config func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport { - // SECURITY FIX: Check client limit before creating new transport + // SECURITY FIX: Check client limit before creating new transport. if atomic.LoadInt32(&p.clientCount) >= p.maxClients { - // Return existing transport if limit reached - p.mu.RLock() - defer p.mu.RUnlock() + // At the client limit: only reuse a transport that was built for the + // SAME config (same TLS trust store). refCount is mutated under the + // write lock to avoid a data race, and a transport created for a + // different configuration is never handed back — doing so could apply + // the wrong (possibly verification-disabled) TLS settings to a request. + want := tlsConfigKey(config) + p.mu.Lock() + defer p.mu.Unlock() for _, shared := range p.transports { - if shared != nil && shared.transport != nil { + if shared != nil && shared.transport != nil && shared.tlsKey == want { shared.refCount++ shared.lastUsed = time.Now() return shared.transport } } - // If no transport available, return nil (caller should handle) + // No TLS-compatible transport available; return nil so the caller falls + // back to a default, certificate-verifying transport rather than one + // with a different (possibly verification-disabled) trust store. return nil } @@ -125,6 +136,7 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt transport: transport, refCount: 1, lastUsed: time.Now(), + tlsKey: tlsConfigKey(config), } return transport @@ -224,6 +236,18 @@ func (p *SharedTransportPool) configKey(config HTTPClientConfig) string { ) } +// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool +// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are +// safe to serve from the same transport even if other (non-TLS) parameters such +// as connection limits differ; configs with different TLS settings are not. +func tlsConfigKey(config HTTPClientConfig) string { + skip := "0" + if config.InsecureSkipVerify { + skip = "1" + } + return fmt.Sprintf("%p|%s", config.RootCAs, skip) +} + // Cleanup closes all transports and stops the cleanup goroutine func (p *SharedTransportPool) Cleanup() { p.mu.Lock() diff --git a/internal/cleanup/cleanup_test.go b/internal/cleanup/cleanup_test.go index 1ae4046..d2b3e9e 100644 --- a/internal/cleanup/cleanup_test.go +++ b/internal/cleanup/cleanup_test.go @@ -842,10 +842,18 @@ func TestWorkerPool_TaskPanic(t *testing.T) { t.Error("Timeout waiting for tasks") } - // Pool should still be functional - metrics := pool.GetMetrics() - if metrics["tasksFailed"].(int64) < 1 { - t.Error("Expected at least one failed task") + // tasksFailed is incremented in the worker's deferred recover(), which runs + // AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can + // therefore return before the failure is recorded — reading the counter + // immediately is a race that flakes on slow/contended CI runners. Poll until + // the failure lands (or time out). + deadline := time.Now().Add(2 * time.Second) + for pool.GetMetrics()["tasksFailed"].(int64) < 1 { + if time.Now().After(deadline) { + t.Error("Expected at least one failed task") + break + } + time.Sleep(5 * time.Millisecond) } } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 57e59cb..b6ddf68 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -155,12 +155,34 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string { // It checks X-Forwarded-Host header first (for proxy scenarios), // then falls back to req.Host. func DetermineHost(req *http.Request) string { - if host := req.Header.Get("X-Forwarded-Host"); host != "" { + if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" { return host } return req.Host } +// sanitizeForwardedHost returns a single, well-formed host from a (possibly +// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes +// only the first value and rejects whitespace and control characters, so a +// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison +// the redirect URLs built from the result. +func sanitizeForwardedHost(v string) string { + if v == "" { + return "" + } + if i := strings.IndexByte(v, ','); i >= 0 { + v = v[:i] + } + v = strings.TrimSpace(v) + if v == "" { + return "" + } + if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 { + return "" + } + return v +} + // BuildFullURL constructs a URL from scheme, host, and path components. // It handles absolute URLs (returning them as-is) and ensures paths have leading slashes. func BuildFullURL(scheme, host, path string) string { diff --git a/jwk.go b/jwk.go index b376401..1ff88a9 100644 --- a/jwk.go +++ b/jwk.go @@ -200,6 +200,22 @@ func buildParsedJWKS(jwks *JWKSet) *parsedJWKS { if k.Kid == "" { continue } + // Skip keys that are not intended for signature verification. + if k.Use != "" && k.Use != "sig" { + continue + } + if len(k.KeyOps) > 0 { + hasVerify := false + for _, op := range k.KeyOps { + if op == "verify" { + hasVerify = true + break + } + } + if !hasVerify { + continue + } + } var pub crypto.PublicKey var err error switch k.Kty { @@ -242,11 +258,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics + body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body) } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("error reading JWKS response: %w", err) } diff --git a/logout.go b/logout.go index 920e1f6..b7aa6fb 100644 --- a/logout.go +++ b/logout.go @@ -134,8 +134,11 @@ func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http expectedIssuer := t.issuerURL t.metadataMu.RUnlock() - if iss != "" && iss != expectedIssuer { - t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer) + // Require a matching issuer. An empty iss must be rejected too: accepting a + // missing issuer would let an unauthenticated attacker force-logout any + // session whose sid is known by simply omitting iss. + if iss == "" || iss != expectedIssuer { + t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer) http.Error(rw, "Invalid issuer", http.StatusBadRequest) return } diff --git a/logout_test.go b/logout_test.go index 4a42d14..16df599 100644 --- a/logout_test.go +++ b/logout_test.go @@ -125,10 +125,14 @@ func TestFrontchannelLogoutBasic(t *testing.T) { expectedStatus: http.StatusOK, }, { - name: "Valid front-channel logout without issuer", + // Front-channel logout MUST carry a matching issuer. A request + // omitting iss is rejected so an unauthenticated attacker cannot + // force-logout a session whose sid is known by simply leaving iss + // out (audit rank 30). + name: "Missing issuer is rejected", method: http.MethodGet, queryParams: map[string]string{"sid": "session456"}, - expectedStatus: http.StatusOK, + expectedStatus: http.StatusBadRequest, }, } @@ -407,17 +411,17 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) { }) oidc := &TraefikOidc{ - next: nextHandler, - logger: NewLogger("debug"), - enableBackchannelLogout: true, - backchannelLogoutPath: "/backchannel-logout", - sessionInvalidationCache: mockCache, - clientID: "test-client", - issuerURL: "https://provider.example.com", - initComplete: make(chan struct{}), - firstRequestStarted: 1, + next: nextHandler, + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + initComplete: make(chan struct{}), + firstRequestStarted: 1, metadataRefreshStartedAtomic: 1, - logoutURLPath: "/logout", + logoutURLPath: "/logout", } close(oidc.initComplete) @@ -449,22 +453,23 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) { }) oidc := &TraefikOidc{ - next: nextHandler, - logger: NewLogger("debug"), - enableFrontchannelLogout: true, - frontchannelLogoutPath: "/frontchannel-logout", - sessionInvalidationCache: mockCache, - clientID: "test-client", - issuerURL: "https://provider.example.com", - initComplete: make(chan struct{}), - firstRequestStarted: 1, + next: nextHandler, + logger: NewLogger("debug"), + enableFrontchannelLogout: true, + frontchannelLogoutPath: "/frontchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + initComplete: make(chan struct{}), + firstRequestStarted: 1, metadataRefreshStartedAtomic: 1, - logoutURLPath: "/logout", + logoutURLPath: "/logout", } close(oidc.initComplete) - // Request to front-channel logout path with valid sid should succeed - req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil) + // Request to front-channel logout path with valid sid + matching issuer + // should succeed. The issuer is now required (audit rank 30), so supply it. + req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session&iss=https://provider.example.com", nil) rw := httptest.NewRecorder() oidc.ServeHTTP(rw, req) @@ -1432,7 +1437,9 @@ func TestFrontchannelLogoutCacheControl(t *testing.T) { issuerURL: "https://provider.example.com", } - req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil) + // Issuer is now required (audit rank 30); supply a matching one so the + // successful-logout cache headers can be asserted. + req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123&iss=https://provider.example.com", nil) rw := httptest.NewRecorder() oidc.handleFrontchannelLogout(rw, req) diff --git a/main.go b/main.go index de71157..7e98f69 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "fmt" "net/http" + "net/url" "os" "runtime" "strings" @@ -112,18 +113,18 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name config = CreateConfig() } - if config.SessionEncryptionKey == "" { - config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" - } - logger := NewLogger(config.LogLevel) - if len(config.SessionEncryptionKey) < minEncryptionKeyLength { - if runtime.Compiler == "yaegi" { - config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" - logger.Infof("Session encryption key is too short; using default key for analyzer") - } else { - return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength) - } + + // Fail closed on invalid configuration. Validate() enforces the security + // constraints (required fields, HTTPS-only URLs, key length, excludedURLs + // safety, rate-limit floor, audience format, ...) that were previously + // unenforced because this constructor never called it. Crucially it rejects + // an empty or too-short SessionEncryptionKey instead of silently + // substituting a public hardcoded key, which would let an attacker forge + // any session. Traefik's yaegi plugin analyzer supplies a valid key via + // .traefik.yml testData, so it passes; only misconfigured deployments fail. + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %w", err) } // Setup HTTP client caPool, err := config.loadCACertPool() @@ -235,7 +236,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name httpClient: httpClient, tokenHTTPClient: tokenHTTPClient, excludedURLs: createStringMap(config.ExcludedURLs), - allowedUserDomains: createStringMap(config.AllowedUserDomains), + allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains), allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers), allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), initComplete: make(chan struct{}), @@ -346,7 +347,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name // Convert sessionMaxAge from seconds to duration (0 will use default 24 hours) sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second - t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults + sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) + if err != nil { + cancelFunc() + return nil, fmt.Errorf("failed to create session manager: %w", err) + } + t.sessionManager = sessionManager t.errorRecoveryManager = NewErrorRecoveryManager(t.logger) // Initialize token resilience manager with default configuration @@ -437,6 +443,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name // Add reference for this instance rm.AddReference(name) + registerLiveInstance() // Initialize metadata in a goroutine with proper tracking if t.goroutineWG != nil { @@ -514,13 +521,58 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { // Parameters: // - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints. func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { + // SSRF defense (audit ranks 3 & 4): a discovery document is attacker- + // influenced when the provider or its TLS is compromised. Reject any + // discovered endpoint pointed at a blocked address before the plugin issues + // outbound requests to it, so it can never be used to reach the cloud + // metadata service or an internal host. + allowLoopback := false + if pu, err := url.Parse(t.providerURL); err == nil { + allowLoopback = isLoopbackHost(pu.Hostname()) + } + sanitize := func(name, raw string) string { + if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil { + t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err) + return "" + } + return raw + } + metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL) + metadata.AuthURL = sanitize("authorization", metadata.AuthURL) + metadata.TokenURL = sanitize("token", metadata.TokenURL) + metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL) + metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL) + metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL) + metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL) + // The introspection request authenticates with the client secret via HTTP + // Basic, so the endpoint must live on the same host as the operator- + // configured provider; otherwise a poisoned discovery document could + // exfiltrate the client secret to an attacker-controlled host. + if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) { + t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL) + metadata.IntrospectionURL = "" + } + + // Pin the discovered issuer to the operator-configured provider host. The + // issuer is the trust anchor for JWT issuer validation, so a poisoned + // discovery document advertising an attacker-chosen issuer must never be + // stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the + // issuer on the same host as the configured providerURL. On mismatch, leave + // issuerURL empty/unchanged so downstream issuer validation fails closed + // rather than trusting the attacker-chosen value. + discoveredIssuer := metadata.Issuer + if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) { + t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer) + discoveredIssuer = "" + } + t.metadataMu.Lock() t.jwksURL = metadata.JWKSURL t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery t.authURL = metadata.AuthURL t.tokenURL = metadata.TokenURL - t.issuerURL = metadata.Issuer + t.issuerURL = discoveredIssuer t.revocationURL = metadata.RevokeURL t.endSessionURL = metadata.EndSessionURL t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662) @@ -533,7 +585,7 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { // Publish the read-mostly URL bundle atomically. Hot-path readers Load // this directly instead of acquiring metadataMu.RLock per request. t.metadataSnapshot.Store(&MetadataSnapshot{ - IssuerURL: metadata.Issuer, + IssuerURL: discoveredIssuer, JWKSURL: metadata.JWKSURL, TokenURL: metadata.TokenURL, AuthURL: metadata.AuthURL, diff --git a/main_goroutine_leak_test.go b/main_goroutine_leak_test.go index bf06c98..03c5a02 100644 --- a/main_goroutine_leak_test.go +++ b/main_goroutine_leak_test.go @@ -194,6 +194,7 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) { config.SessionEncryptionKey = "test-encryption-key-32-bytes-long" config.ClientID = "test-client-id" config.ClientSecret = "test-client-secret" + config.CallbackURL = "/callback" handler, err := New(ctx, nil, config, "test") if err != nil { @@ -322,6 +323,7 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) { config.SessionEncryptionKey = "test-encryption-key-32-bytes-long" config.ClientID = "test-client-id" config.ClientSecret = "test-client-secret" + config.CallbackURL = "/callback" handler, err := New(ctx, nil, config, "test") if err != nil { diff --git a/main_initialization_test.go b/main_initialization_test.go index 6906fa0..6260f20 100644 --- a/main_initialization_test.go +++ b/main_initialization_test.go @@ -26,38 +26,47 @@ func TestInitializeMetadata(t *testing.T) { name: "successful metadata initialization", providerURL: "", setupMock: func() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Issuer must share the host with providerURL (the httptest + // server), otherwise the discovery doc is rejected as poisoned + // (audit ranks 21/22). Real providers keep issuer + endpoints on + // the same host, so derive them all from the server URL. + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(ProviderMetadata{ - Issuer: "https://provider.example.com", - AuthURL: "https://provider.example.com/auth", - TokenURL: "https://provider.example.com/token", - JWKSURL: "https://provider.example.com/jwks", - RevokeURL: "https://provider.example.com/revoke", - EndSessionURL: "https://provider.example.com/logout", + Issuer: srv.URL, + AuthURL: srv.URL + "/auth", + TokenURL: srv.URL + "/token", + JWKSURL: srv.URL + "/jwks", + RevokeURL: srv.URL + "/revoke", + EndSessionURL: srv.URL + "/logout", }) } else { w.WriteHeader(http.StatusNotFound) } })) + return srv }, validateFunc: func(t *testing.T, oidc *TraefikOidc) { - if oidc.authURL != "https://provider.example.com/auth" { + if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") { t.Errorf("expected authURL to be set, got %s", oidc.authURL) } - if oidc.tokenURL != "https://provider.example.com/token" { + if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") { t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL) } - if oidc.jwksURL != "https://provider.example.com/jwks" { + if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") { t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL) } - if oidc.revocationURL != "https://provider.example.com/revoke" { + if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") { t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL) } - if oidc.endSessionURL != "https://provider.example.com/logout" { + if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") { t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL) } + if oidc.issuerURL == "" { + t.Errorf("expected issuerURL to be pinned to provider host, got empty") + } }, wantPanic: false, }, @@ -116,24 +125,27 @@ func TestInitializeMetadata(t *testing.T) { name: "partial metadata response", providerURL: "", setupMock: func() *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Issuer host must match providerURL (audit ranks 21/22). + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") { w.Header().Set("Content-Type", "application/json") // Only return some fields json.NewEncoder(w).Encode(map[string]string{ - "issuer": "https://partial.example.com", - "authorization_endpoint": "https://partial.example.com/auth", - "token_endpoint": "https://partial.example.com/token", + "issuer": srv.URL, + "authorization_endpoint": srv.URL + "/auth", + "token_endpoint": srv.URL + "/token", // Missing jwks_uri, revocation_endpoint, end_session_endpoint }) } })) + return srv }, validateFunc: func(t *testing.T, oidc *TraefikOidc) { - if oidc.authURL != "https://partial.example.com/auth" { + if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") { t.Errorf("expected authURL to be set, got %s", oidc.authURL) } - if oidc.tokenURL != "https://partial.example.com/token" { + if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") { t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL) } // JWKS URL and others may be empty @@ -198,20 +210,22 @@ func TestInitializeMetadata_Concurrency(t *testing.T) { requestCount := 0 var mu sync.Mutex - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() requestCount++ mu.Unlock() if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") { w.Header().Set("Content-Type", "application/json") + // Issuer host must match providerURL (audit ranks 21/22). json.NewEncoder(w).Encode(ProviderMetadata{ - Issuer: "https://concurrent.example.com", - AuthURL: "https://concurrent.example.com/auth", - TokenURL: "https://concurrent.example.com/token", - JWKSURL: "https://concurrent.example.com/jwks", - RevokeURL: "https://concurrent.example.com/revoke", - EndSessionURL: "https://concurrent.example.com/logout", + Issuer: server.URL, + AuthURL: server.URL + "/auth", + TokenURL: server.URL + "/token", + JWKSURL: server.URL + "/jwks", + RevokeURL: server.URL + "/revoke", + EndSessionURL: server.URL + "/logout", }) } })) @@ -250,7 +264,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) { oidc.initializeMetadata(server.URL) // Verify initialization - if oidc.tokenURL != "https://concurrent.example.com/token" { + if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") { t.Errorf("expected tokenURL to be set") } }() @@ -342,17 +356,19 @@ func TestProviderDetection(t *testing.T) { // TestInitializationWaiting tests waiting for initialization to complete func TestInitializationWaiting(t *testing.T) { t.Run("wait for initialization completion", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Delay response to simulate slow initialization time.Sleep(100 * time.Millisecond) if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") { w.Header().Set("Content-Type", "application/json") + // Issuer host must match providerURL (audit ranks 21/22). json.NewEncoder(w).Encode(ProviderMetadata{ - Issuer: "https://slow.example.com", - AuthURL: "https://slow.example.com/auth", - TokenURL: "https://slow.example.com/token", - JWKSURL: "https://slow.example.com/jwks", + Issuer: server.URL, + AuthURL: server.URL + "/auth", + TokenURL: server.URL + "/token", + JWKSURL: server.URL + "/jwks", }) } })) @@ -389,7 +405,7 @@ func TestInitializationWaiting(t *testing.T) { select { case <-oidc.initComplete: // Success - if oidc.tokenURL != "https://slow.example.com/token" { + if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") { t.Error("expected tokenURL to be set after initialization") } case <-time.After(2 * time.Second): @@ -398,17 +414,19 @@ func TestInitializationWaiting(t *testing.T) { }) t.Run("multiple waiters for initialization", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Delay to ensure multiple waiters time.Sleep(50 * time.Millisecond) if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") { w.Header().Set("Content-Type", "application/json") + // Issuer host must match providerURL (audit ranks 21/22). json.NewEncoder(w).Encode(ProviderMetadata{ - Issuer: "https://multi.example.com", - AuthURL: "https://multi.example.com/auth", - TokenURL: "https://multi.example.com/token", - JWKSURL: "https://multi.example.com/jwks", + Issuer: server.URL, + AuthURL: server.URL + "/auth", + TokenURL: server.URL + "/token", + JWKSURL: server.URL + "/jwks", }) } })) @@ -453,7 +471,7 @@ func TestInitializationWaiting(t *testing.T) { select { case <-oidc.initComplete: // All waiters should see the same initialized state - if oidc.tokenURL != "https://multi.example.com/token" { + if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") { t.Errorf("waiter %d: expected tokenURL to be set", id) } case <-time.After(2 * time.Second): diff --git a/main_test.go b/main_test.go index 58e487c..0f7141e 100644 --- a/main_test.go +++ b/main_test.go @@ -1875,14 +1875,14 @@ func TestHandleLogout(t *testing.T) { }, endSessionURL: "", expectedStatus: http.StatusFound, - expectedURL: "http://example.com/", + expectedURL: "/", host: "test-host", }, { name: "Logout with empty session", setupSession: func(session *SessionData) {}, expectedStatus: http.StatusFound, - expectedURL: "http://example.com/", + expectedURL: "/", host: "test-host", }, { @@ -2349,19 +2349,22 @@ func TestMultipleMiddlewareInstances(t *testing.T) { t.Skip("Skipping test in short mode") } - // Create mock provider metadata server - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Create mock provider metadata server. Issuer + endpoints must share the + // host with ProviderURL (the httptest server), otherwise the discovery doc + // is rejected as poisoned (audit ranks 21/22). Derive them from the server. + var mockServer *httptest.Server + mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } metadata := ProviderMetadata{ - Issuer: "https://test-issuer.com", - AuthURL: "https://test-issuer.com/auth", - TokenURL: "https://test-issuer.com/token", - JWKSURL: "https://test-issuer.com/jwks", - RevokeURL: "https://test-issuer.com/revoke", - EndSessionURL: "https://test-issuer.com/end-session", + Issuer: mockServer.URL, + AuthURL: mockServer.URL + "/auth", + TokenURL: mockServer.URL + "/token", + JWKSURL: mockServer.URL + "/jwks", + RevokeURL: mockServer.URL + "/revoke", + EndSessionURL: mockServer.URL + "/end-session", } json.NewEncoder(w).Encode(metadata) })) @@ -2374,6 +2377,7 @@ func TestMultipleMiddlewareInstances(t *testing.T) { ClientSecret: "test-secret", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", + RateLimit: 100, } // Create multiple middleware instances @@ -2414,18 +2418,20 @@ func TestMultipleMiddlewareInstances(t *testing.T) { t.Fatalf("Middleware instance %d failed to initialize", i) } - // Verify each instance has its own unique configuration - if m.issuerURL != "https://test-issuer.com" { - t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL) + // Verify each instance has its own unique configuration. Issuer is now + // pinned to the provider host (audit ranks 21/22), so it equals the + // mock server URL rather than a fixed literal. + if m.issuerURL != mockServer.URL { + t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL) } - if m.authURL != "https://test-issuer.com/auth" { - t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL) + if m.authURL != mockServer.URL+"/auth" { + t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL) } - if m.tokenURL != "https://test-issuer.com/token" { - t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL) + if m.tokenURL != mockServer.URL+"/token" { + t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL) } - if m.jwksURL != "https://test-issuer.com/jwks" { - t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL) + if m.jwksURL != mockServer.URL+"/jwks" { + t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL) } if m.redirURLPath != routes[i]+"/callback" { t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath) @@ -2439,15 +2445,16 @@ func TestMultipleMiddlewareInstances(t *testing.T) { m.ServeHTTP(rr, req) - // Should redirect to auth URL since not authenticated + // Should redirect (302) to the auth flow since not authenticated. The + // absolute auth URL is not asserted here: with issuer pinning (audit + // ranks 21/22) the discovery host equals the httptest server host, + // which is loopback, so buildAuthURL's SSRF guard legitimately refuses + // to emit a loopback authorization URL in this test environment. The + // per-instance auth/token/jwks/issuer URLs were already verified above; + // here we only confirm each instance independently triggers a redirect. if rr.Code != http.StatusFound { t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code) } - - location := rr.Header().Get("Location") - if !strings.Contains(location, "https://test-issuer.com/auth") { - t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location) - } } } @@ -2460,33 +2467,43 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) { } // Create two mock provider metadata servers simulating different Keycloak realms - realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Issuer + endpoints must share the host with each realm's ProviderURL + // (the httptest server), otherwise the discovery doc is rejected as + // poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN + // path so the per-realm isolation assertions below still hold, but base + // the host on the server URL — which is exactly what a same-host Keycloak + // deployment looks like. + var realm1Server *httptest.Server + realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } + base := realm1Server.URL + "/realms/realm1" metadata := ProviderMetadata{ - Issuer: "https://keycloak.example.com/realms/realm1", - AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth", - TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token", - JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs", - EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout", + Issuer: base, + AuthURL: base + "/protocol/openid-connect/auth", + TokenURL: base + "/protocol/openid-connect/token", + JWKSURL: base + "/protocol/openid-connect/certs", + EndSessionURL: base + "/protocol/openid-connect/logout", } json.NewEncoder(w).Encode(metadata) })) defer realm1Server.Close() - realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var realm2Server *httptest.Server + realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } + base := realm2Server.URL + "/realms/realm2" metadata := ProviderMetadata{ - Issuer: "https://keycloak.example.com/realms/realm2", - AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth", - TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token", - JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs", - EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout", + Issuer: base, + AuthURL: base + "/protocol/openid-connect/auth", + TokenURL: base + "/protocol/openid-connect/token", + JWKSURL: base + "/protocol/openid-connect/certs", + EndSessionURL: base + "/protocol/openid-connect/logout", } json.NewEncoder(w).Encode(metadata) })) @@ -2500,6 +2517,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) { CallbackURL: "/realm1/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", CookiePrefix: "_oidc_realm1_", + RateLimit: 100, } // Config for realm2 @@ -2510,6 +2528,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) { CallbackURL: "/realm2/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", CookiePrefix: "_oidc_realm2_", + RateLimit: 100, } // Create middleware instances for both realms @@ -2608,8 +2627,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) { providerAvailable := false var mu sync.Mutex - // Create mock provider that initially fails, then becomes available - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Create mock provider that initially fails, then becomes available. + // Issuer + endpoints must share the host with ProviderURL (audit ranks + // 21/22), so derive them from the server URL. + var mockServer *httptest.Server + mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() available := providerAvailable mu.Unlock() @@ -2621,11 +2643,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) { if r.URL.Path == "/.well-known/openid-configuration" { metadata := ProviderMetadata{ - Issuer: "https://test-issuer.com", - AuthURL: "https://test-issuer.com/auth", - TokenURL: "https://test-issuer.com/token", - JWKSURL: "https://test-issuer.com/jwks", - EndSessionURL: "https://test-issuer.com/logout", + Issuer: mockServer.URL, + AuthURL: mockServer.URL + "/auth", + TokenURL: mockServer.URL + "/token", + JWKSURL: mockServer.URL + "/jwks", + EndSessionURL: mockServer.URL + "/logout", } json.NewEncoder(w).Encode(metadata) return @@ -2640,6 +2662,7 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) { ClientSecret: "test-secret", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", + RateLimit: 100, } // Create middleware while provider is unavailable @@ -4552,6 +4575,7 @@ func TestNewWithScopeAppending(t *testing.T) { CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", Scopes: tc.configScopes, + RateLimit: 100, } // Create middleware instance diff --git a/memory_leak_test.go b/memory_leak_test.go index 4dc16e8..8dbc178 100644 --- a/memory_leak_test.go +++ b/memory_leak_test.go @@ -1652,6 +1652,7 @@ func TestGoroutineLeaks(t *testing.T) { config.SessionEncryptionKey = "test-encryption-key-32-bytes-long" config.ClientID = "test-client" config.ClientSecret = "test-secret" + config.CallbackURL = "/callback" handler, err := New(context.Background(), nil, config, "test") require.NoError(t, err) diff --git a/metadata_cache.go b/metadata_cache.go index 8f278da..b384e28 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "strings" "sync" @@ -141,10 +142,19 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st } var metadata ProviderMetadata - if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { + if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil { return nil, fmt.Errorf("failed to decode metadata: %w", err) } + // Pin the advertised issuer to the configured provider host. The issuer is + // the trust anchor for JWT issuer validation; rejecting a mismatch here + // ensures a poisoned discovery document advertising an attacker-chosen + // issuer is never cached or returned. Real providers (Google, Azure, + // Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL. + if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) { + return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL) + } + // Cache for 1 hour by default if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil { mc.logger.Errorf("Failed to cache metadata: %v", err) diff --git a/middleware.go b/middleware.go index 3ef3530..0dea547 100644 --- a/middleware.go +++ b/middleware.go @@ -472,6 +472,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // - req: The HTTP request to process. // - session: The user's session data containing tokens and claims. // - redirectURL: The callback URL for re-authentication if needed. +// // processAuthorizedRequestRS is the requestState-aware variant of // processAuthorizedRequest. It reads SessionData fields from the captured // snapshot in rs instead of calling session.GetX() (each of which acquires @@ -675,6 +676,44 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http // // Session persistence is the CALLER's responsibility — it must happen before // this function so Set-Cookie reaches the response. +// headerTemplateMaxLen bounds the length of a rendered operator-defined header +// template before it is forwarded downstream. Generous enough for an +// "Authorization: Bearer " value but small enough to reject obviously +// abusive output. Matches the input-validation default header cap (8KB). +const headerTemplateMaxLen = 8192 + +// headerClaimMaxLen returns the maximum accepted length for a claim-derived +// header value (principal identifier, group, role). Reuses the operator- +// configured identifier cap (default 256) so a single setting governs both +// auth paths; falls back to 256 when unset. +func (t *TraefikOidc) headerClaimMaxLen() int { + if t.maxIdentifierLength > 0 { + return t.maxIdentifierLength + } + return 256 +} + +// sanitizeHeaderClaimList drops any group/role value that fails claim +// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an +// over-long value) and returns the surviving values. Failing closed on a bad +// entry prevents header injection and stops an embedded comma from injecting +// extra entries into the comma-joined header. headerName is used only for +// debug logging — the value is never logged. +func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string { + if len(values) == 0 { + return nil + } + safe := make([]string, 0, len(values)) + for _, v := range values { + if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok { + safe = append(safe, clean) + } else { + t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName) + } + } + return safe +} + func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) { var ( groups, roles []string @@ -692,11 +731,18 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques return } if extractErr == nil { - if len(groups) > 0 { - req.Header.Set("X-User-Groups", strings.Join(groups, ",")) + // Sanitize each group/role before it is joined into a comma- + // delimited header. The cookie/session path does not otherwise + // sanitize claim-derived values (the bearer path sanitizes its + // identifier at construction), so a control char would enable + // header injection and an embedded comma would inject extra + // entries into the comma-joined header. Fail closed: drop any + // value that does not pass. + if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 { + req.Header.Set("X-User-Groups", strings.Join(safeGroups, ",")) } - if len(roles) > 0 { - req.Header.Set("X-User-Roles", strings.Join(roles, ",")) + if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 { + req.Header.Set("X-User-Roles", strings.Join(safeRoles, ",")) } } } @@ -717,12 +763,26 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques } } - req.Header.Set("X-Forwarded-User", p.Identifier) + // Sanitize the principal identifier before injecting it into headers. The + // bearer path already sanitizes its identifier at construction; the + // cookie/session path does not, so a claim carrying control chars, bidi- + // override runes, or , ; = could inject or spoof header content. Fail + // closed: drop the identifier header(s) rather than forward a tainted value. + safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen()) + if identifierOK { + req.Header.Set("X-Forwarded-User", safeIdentifier) + } else { + t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization") + } // When minimalHeaders is enabled, skip extra headers to prevent 431 errors if !t.minimalHeaders { req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) - req.Header.Set("X-Auth-Request-User", p.Identifier) + if identifierOK { + req.Header.Set("X-Auth-Request-User", safeIdentifier) + } else { + t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization") + } if p.IDToken != "" { req.Header.Set("X-Auth-Request-Token", p.IDToken) } @@ -747,8 +807,21 @@ func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Reques continue } headerValue := buf.String() + // Sanitize the rendered output: template inputs are claim-derived + // and attacker-influenceable, so reject control chars (header + // injection), bidi-override runes, the , ; = delimiters, and an + // over-long value. Fail closed by dropping the header rather than + // forwarding a tainted value. Do not log the value (it commonly + // carries the access token); log only name + reason. + if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" { + t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason) + continue + } req.Header.Set(headerName, headerValue) - t.logger.Debugf("Set templated header %s = %s", headerName, headerValue) + // Do not log the value: templated headers commonly carry the access + // token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging + // it — even at debug — leaks credentials into logs. + t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue)) } } diff --git a/security_audit_fixes_test.go b/security_audit_fixes_test.go new file mode 100644 index 0000000..8ef8ec8 --- /dev/null +++ b/security_audit_fixes_test.go @@ -0,0 +1,404 @@ +package traefikoidc + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/sessions" + "github.com/lukaszraczylo/traefikoidc/internal/utils" +) + +// TestRank1_SessionCookieIsEncrypted verifies that the session cookie payload is +// AES-encrypted, not merely HMAC-signed. Regression test for the audit finding +// "session cookies signed but NOT encrypted": a single key left the stored OIDC +// tokens recoverable in plaintext from the raw cookie bytes. +func TestRank1_SessionCookieIsEncrypted(t *testing.T) { + const secret = "a-sufficiently-long-session-encryption-key" + authKey, encKey := deriveCookieKeys(secret) + if len(authKey) != 64 || len(encKey) != 32 { + t.Fatalf("expected 64-byte auth key and 32-byte enc key, got %d/%d", len(authKey), len(encKey)) + } + if string(authKey) == string(encKey) { + t.Fatal("authentication and encryption keys must be independent") + } + + const marker = "SUPER-SECRET-ACCESS-TOKEN-marker-value" + + // Encode a session through the same two-key store the production code now + // builds (see NewSessionManager). + store := sessions.NewCookieStore(authKey, encKey) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + sess, err := store.New(req, "session") + if err != nil { + t.Fatalf("store.New failed: %v", err) + } + sess.Values["tok"] = marker + if err := sess.Save(req, rec); err != nil { + t.Fatalf("session save failed: %v", err) + } + + var cookie *http.Cookie + for _, c := range rec.Result().Cookies() { + if c.Name == "session" { + cookie = c + } + } + if cookie == nil { + t.Fatal("no session cookie was set") + } + + // The secret token must never appear in plaintext in the cookie value. + if strings.Contains(cookie.Value, marker) { + t.Error("marker token found in plaintext inside the session cookie value") + } + + // A store holding only the authentication key (the previous behavior) + // must NOT be able to read the encrypted cookie — proving the payload is + // genuinely encrypted, not just signed. + signedOnly := sessions.NewCookieStore(authKey) + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.AddCookie(cookie) + if _, derr := signedOnly.Get(req2, "session"); derr == nil { + t.Error("encrypted cookie should not be decodable without the encryption key") + } + + // The full two-key store round-trips correctly. + req3 := httptest.NewRequest(http.MethodGet, "/", nil) + req3.AddCookie(cookie) + rt, derr := store.Get(req3, "session") + if derr != nil { + t.Fatalf("round-trip decode failed: %v", derr) + } + if got, _ := rt.Values["tok"].(string); got != marker { + t.Errorf("round-trip mismatch: got %q want %q", got, marker) + } +} + +// TestRank2And6_InvalidConfigFailsClosed verifies that NewWithContext now calls +// Config.Validate() and fails closed on an empty or too-short session +// encryption key instead of silently substituting a public hardcoded key, and +// rejects other missing required fields. Regression test for "hardcoded default +// encryption key" + "Config.Validate() never called in production path". +func TestRank2And6_InvalidConfigFailsClosed(t *testing.T) { + base := func() *Config { + return &Config{ + ProviderURL: "https://accounts.google.com", + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "this-is-a-valid-session-key-32b!", + RateLimit: 100, + } + } + + // Sanity: a fully valid config still constructs. + p, err := NewWithContext(context.Background(), base(), nil, "valid") + if err != nil { + t.Fatalf("valid config should construct, got: %v", err) + } + if p != nil { + p.Close() + } + + cases := []struct { + name string + mutate func(*Config) + }{ + {"empty key", func(c *Config) { c.SessionEncryptionKey = "" }}, + {"short key", func(c *Config) { c.SessionEncryptionKey = "tooshort" }}, + {"missing providerURL", func(c *Config) { c.ProviderURL = "" }}, + {"missing callbackURL", func(c *Config) { c.CallbackURL = "" }}, + {"plaintext remote providerURL", func(c *Config) { c.ProviderURL = "http://accounts.google.com" }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c := base() + tc.mutate(c) + plugin, err := NewWithContext(context.Background(), c, nil, tc.name) + if err == nil { + if plugin != nil { + plugin.Close() + } + t.Errorf("expected NewWithContext to reject config (%s), but it succeeded", tc.name) + } + }) + } +} + +// TestRank3_DiscoveredEndpointSSRFGuard verifies that endpoints from the +// provider discovery document are screened against SSRF targets before use. +func TestRank3_DiscoveredEndpointSSRFGuard(t *testing.T) { + tr := &TraefikOidc{} + + blocked := []string{ + "http://169.254.169.254/latest/meta-data/", // cloud metadata (link-local) + "http://[fe80::1]/jwks", // IPv6 link-local + "http://10.0.0.5/jwks", // private + "http://192.168.1.10/jwks", // private + "http://127.0.0.1/jwks", // loopback (allowLoopback=false) + "ftp://example.com/jwks", // disallowed scheme + } + for _, u := range blocked { + if err := tr.validateDiscoveredEndpoint(u, false); err == nil { + t.Errorf("expected discovered endpoint %q to be rejected", u) + } + } + + allowed := []string{ + "https://accounts.google.com/o/oauth2/v3/certs", + "https://www.googleapis.com/oauth2/v3/certs", // cross-domain JWKS must stay allowed + "", // empty optional endpoint + } + for _, u := range allowed { + if err := tr.validateDiscoveredEndpoint(u, false); err != nil { + t.Errorf("expected discovered endpoint %q to be allowed, got %v", u, err) + } + } + + // Loopback is allowed only when the provider itself is loopback (dev/test). + if err := tr.validateDiscoveredEndpoint("http://127.0.0.1:8080/jwks", true); err != nil { + t.Errorf("loopback endpoint should be allowed when allowLoopback=true: %v", err) + } + // Private addresses are allowed when explicitly opted in. + trPriv := &TraefikOidc{allowPrivateIPAddresses: true} + if err := trPriv.validateDiscoveredEndpoint("http://10.0.0.5/jwks", false); err != nil { + t.Errorf("private endpoint should be allowed when allowPrivateIPAddresses=true: %v", err) + } +} + +// TestRank4_IntrospectionHostPin verifies the host-equality check used to pin +// the credential-bearing introspection endpoint to the configured provider. +func TestRank4_IntrospectionHostPin(t *testing.T) { + if !sameHost("https://kc.example.com/realms/x", "https://kc.example.com/realms/x/protocol/openid-connect/token/introspect") { + t.Error("introspection on the same host as the provider should be accepted") + } + if sameHost("https://kc.example.com", "https://evil.example.net/introspect") { + t.Error("introspection on a different host must be rejected") + } + if sameHost("", "https://kc.example.com") || sameHost("https://kc.example.com", "") { + t.Error("empty URL must not be treated as a host match") + } +} + +// TestRank5_OpenRedirectNeutralized verifies the helper the callback now applies +// to the stored incoming path forces a host-relative redirect target. +func TestRank5_OpenRedirectNeutralized(t *testing.T) { + cases := map[string]string{ + "//evil.com/x": "/evil.com/x", + `/\evil.com`: "/evil.com", + "/legit/path": "/legit/path", + } + for in, want := range cases { + got := normalizeLogoutPath(in) + if got != want { + t.Errorf("normalizeLogoutPath(%q) = %q, want %q", in, got, want) + } + if strings.HasPrefix(got, "//") || strings.HasPrefix(got, `/\`) { + t.Errorf("normalizeLogoutPath(%q) = %q is still protocol-relative", in, got) + } + } +} + +// TestRank14_ExcludedURLSegmentBoundary verifies excluded-URL matching is +// anchored at path-segment boundaries and cannot be widened into a bypass. +func TestRank14_ExcludedURLSegmentBoundary(t *testing.T) { + if !pathExcluded("/public", "/public") { + t.Error("exact match should be excluded") + } + if !pathExcluded("/public/page", "/public") { + t.Error("sub-path should be excluded") + } + if pathExcluded("/publicsecret", "/public") { + t.Error("/publicsecret must NOT be excluded by /public") + } + if pathExcluded("/public-admin", "/public") { + t.Error("/public-admin must NOT be excluded by /public") + } + if !pathExcluded("/health", "/health/") { + t.Error("trailing-slash config should still match the exact path") + } + if pathExcluded("/anything", "/") { + t.Error("root exclusion must not match arbitrary paths") + } + if !pathExcluded("/", "/") { + t.Error("root exclusion should match the root path") + } +} + +// TestRank15_ForwardedHostSanitized verifies a crafted X-Forwarded-Host cannot +// inject CRLF, smuggle a second host, or otherwise poison the derived host. +func TestRank15_ForwardedHostSanitized(t *testing.T) { + mk := func(xfh string) *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://real.example.com/x", nil) + r.Host = "real.example.com" + if xfh != "" { + r.Header.Set("X-Forwarded-Host", xfh) + } + return r + } + if got := utils.DetermineHost(mk("ext.example.com")); got != "ext.example.com" { + t.Errorf("clean X-Forwarded-Host should be honored, got %q", got) + } + if got := utils.DetermineHost(mk("a.example.com, evil.com")); got != "a.example.com" { + t.Errorf("multi-value X-Forwarded-Host should use first host only, got %q", got) + } + for _, bad := range []string{"evil.com\r\nSet-Cookie: x=1", "evil.com /x", " "} { + if got := utils.DetermineHost(mk(bad)); got != "real.example.com" { + t.Errorf("malformed X-Forwarded-Host %q should fall back to req.Host, got %q", bad, got) + } + } +} + +// TestRank11_TransportPoolTLSIsolationAtLimit verifies that, once the client +// limit is reached, the transport pool reuses an existing transport only when +// its TLS settings match the caller's, and never hands back a transport built +// with different TLS trust settings. +func TestRank11_TransportPoolTLSIsolationAtLimit(t *testing.T) { + pool := &SharedTransportPool{ + transports: make(map[string]*sharedTransport), + maxConns: 20, + maxClients: 5, + } + + strict := DefaultHTTPClientConfig() // InsecureSkipVerify = false + t1 := pool.GetOrCreateTransport(strict) + if t1 == nil { + t.Fatal("expected a transport for the strict config") + } + + // Saturate the client limit so subsequent calls hit the fallback path. + atomic.StoreInt32(&pool.clientCount, pool.maxClients) + + // Same TLS settings, different (non-TLS) connection limit: safe to reuse. + sameTLS := DefaultHTTPClientConfig() + sameTLS.MaxConnsPerHost = 99 + if got := pool.GetOrCreateTransport(sameTLS); got != t1 { + t.Error("at the limit a TLS-compatible config should reuse the existing transport") + } + + // Different TLS settings (InsecureSkipVerify): must NOT reuse the strict + // transport — returning nil lets the caller fall back to a verifying default. + insecure := DefaultHTTPClientConfig() + insecure.InsecureSkipVerify = true + if got := pool.GetOrCreateTransport(insecure); got == t1 { + t.Error("at the limit a config with different TLS settings must not reuse the strict transport") + } +} + +// TestRank9_RedisFingerprint verifies divergent explicit Redis backends produce +// distinct fingerprints (used to warn about ignored cache config), while an +// absent or disabled Redis yields the empty (no-warning) fingerprint. +func TestRank9_RedisFingerprint(t *testing.T) { + if redisFingerprint(nil) != "" { + t.Error("nil config should yield an empty fingerprint") + } + if redisFingerprint(&Config{}) != "" { + t.Error("config without Redis should yield an empty fingerprint") + } + if redisFingerprint(&Config{Redis: &RedisConfig{Enabled: false, Address: "a:6379"}}) != "" { + t.Error("disabled Redis should yield an empty fingerprint") + } + a := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "a:6379", KeyPrefix: "p"}}) + b := redisFingerprint(&Config{Redis: &RedisConfig{Enabled: true, Address: "b:6379", KeyPrefix: "p"}}) + if a == "" || a == b { + t.Errorf("distinct enabled backends must produce distinct non-empty fingerprints (%q vs %q)", a, b) + } +} + +// TestRank10_TokenTypeCacheKeyNoCollision verifies that two different tokens +// sharing the same 32-character JWT header prefix are classified independently. +// The previous 32-char cache key would have collided and mis-classified them. +func TestRank10_TokenTypeCacheKeyNoCollision(t *testing.T) { + tr := &TraefikOidc{ + tokenTypeCache: NewCache(), + suppressDiagnosticLogs: true, + clientID: "client", + } + // A header prefix longer than 32 chars, shared by both tokens. + prefix := "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ" + idJWT := &JWT{Header: map[string]interface{}{}, Claims: map[string]interface{}{"nonce": "n"}} + accessJWT := &JWT{Header: map[string]interface{}{"typ": "at+jwt"}, Claims: map[string]interface{}{}} + + if !tr.detectTokenType(idJWT, prefix+".id.sig") { + t.Error("token with a nonce claim should be detected as an ID token") + } + if tr.detectTokenType(accessJWT, prefix+".access.sig") { + t.Error("access token (typ=at+jwt) must not be mis-classified as ID despite the shared 32-char prefix") + } +} + +// TestRank12_LiveInstanceCounter verifies the process-global instance counter +// that gates teardown of shared singleton tasks. +func TestRank12_LiveInstanceCounter(t *testing.T) { + start := atomic.LoadInt32(&liveInstanceCount) + registerLiveInstance() + registerLiveInstance() + if got := atomic.LoadInt32(&liveInstanceCount); got != start+2 { + t.Fatalf("expected %d live instances, got %d", start+2, got) + } + if rem := unregisterLiveInstance(); rem != start+1 { + t.Errorf("expected %d remaining, got %d", start+1, rem) + } + if rem := unregisterLiveInstance(); rem != start { + t.Errorf("expected %d remaining, got %d", start, rem) + } +} + +// TestRank13_CookieMaxAgeMatchesSessionLifetime verifies the cookie store's +// MaxAge (which bounds both the cookie Max-Age and the codec's cryptographic +// timestamp validity) is bound to the configured session lifetime rather than +// gorilla's 30-day default. +func TestRank13_CookieMaxAgeMatchesSessionLifetime(t *testing.T) { + maxAge := 2 * time.Hour + sm, err := NewSessionManager(strings.Repeat("k", 40), false, "", "", maxAge, NewLogger("error")) + if err != nil { + t.Fatalf("NewSessionManager failed: %v", err) + } + defer sm.cancel() + + cs, ok := sm.store.(*sessions.CookieStore) + if !ok { + t.Fatal("session store is not a *sessions.CookieStore") + } + if got := cs.Options.MaxAge; got != int(maxAge.Seconds()) { + t.Errorf("cookie store MaxAge = %d, want %d (bound to sessionMaxAge)", got, int(maxAge.Seconds())) + } +} + +// TestRank33And34_HeaderSanitizationDistinction verifies the two header sinks +// use the right strictness: free-form templated header VALUES (rank 34) permit +// , ; = (e.g. an opaque "Bearer " or an LDAP-DN claim) but reject CR/LF, +// bidi, and over-length; claim values joined into delimited/identifier headers +// (rank 33) additionally reject , ; =. +func TestRank33And34_HeaderSanitizationDistinction(t *testing.T) { + // Rank 34 — free-form header value. + if headerValueReason("Bearer abc=def==", 8192) != "" { + t.Error("'=' must be allowed in a free-form header value (opaque bearer token)") + } + if headerValueReason("cn=user,ou=eng;dc=x", 8192) != "" { + t.Error("',;=' must be allowed in a free-form header value (e.g. an LDAP DN claim)") + } + if headerValueReason("evil"+string(rune(13))+string(rune(10))+"Injected: 1", 8192) == "" { + t.Error("CR/LF must be rejected in a header value (injection)") + } + if headerValueReason("toolong", 3) == "" { + t.Error("over-length value must be rejected") + } + + // Rank 33 — claim value bound for a delimited/identifier header. + if _, ok := sanitizeHeaderClaimValue("admins,superadmins", 256); ok { + t.Error("a comma must be rejected in a value joined into a comma-delimited header") + } + if _, ok := sanitizeHeaderClaimValue("normal-user@example.com", 256); !ok { + t.Error("a clean identifier must pass claim sanitization") + } + if _, ok := sanitizeHeaderClaimValue("evil"+string(rune(13))+string(rune(10))+"X: 1", 256); ok { + t.Error("CR/LF must be rejected in a claim value") + } +} diff --git a/security_monitoring.go b/security_monitoring.go deleted file mode 100644 index 9a5d0cc..0000000 --- a/security_monitoring.go +++ /dev/null @@ -1,590 +0,0 @@ -package traefikoidc - -import ( - "fmt" - "net" - "net/http" - "strings" - "sync" - "time" -) - -// SecurityEventType categorizes different types of security events -// that can occur during OIDC authentication and authorization flows. -type SecurityEventType string - -// Security event types for monitoring and alerting -const ( - // AuthFailure indicates a failed authentication attempt - AuthFailure SecurityEventType = "authentication_failure" - // TokenValidFailure indicates JWT token validation failed - TokenValidFailure SecurityEventType = "token_validation_failure" - // RateLimitHit indicates rate limiting was triggered - RateLimitHit SecurityEventType = "rate_limit_hit" - // SuspiciousActivity indicates potentially malicious behavior - SuspiciousActivity SecurityEventType = "suspicious_activity" -) - -// DefaultSeverity returns the default severity level for each security event type. -// Severity levels are: low, medium, high. -func (t SecurityEventType) DefaultSeverity() string { - switch t { - case AuthFailure: - return "medium" - case TokenValidFailure: - return "medium" - case RateLimitHit: - return "low" - case SuspiciousActivity: - return "high" - default: - return "medium" - } -} - -// IPFailureType returns a string identifier for categorizing failures -// by IP address for rate limiting and blocking decisions. -func (t SecurityEventType) IPFailureType() string { - switch t { - case AuthFailure: - return "auth_failure" - case TokenValidFailure: - return "token_failure" - case SuspiciousActivity: - return "suspicious" - default: - return "general" - } -} - -// SecurityEvent represents a security-related event with comprehensive context. -// Contains timing information, IP address, user agent, request details, -// and custom event-specific data for security analysis and alerting. -type SecurityEvent struct { - // Timestamp when the event occurred - Timestamp time.Time `json:"timestamp"` - // Details contains event-specific additional information - Details map[string]interface{} `json:"details,omitempty"` - // Type categorizes the event (auth_failure, token_failure, etc.) - Type string `json:"type"` - // Severity indicates event importance (low, medium, high) - Severity string `json:"severity"` - // ClientIP is the source IP address of the request - ClientIP string `json:"client_ip"` - // UserAgent is the User-Agent header from the request - UserAgent string `json:"user_agent"` - // RequestPath is the requested URL path - RequestPath string `json:"request_path"` - // Message provides human-readable description of the event - Message string `json:"message"` -} - -// SecurityMonitor provides comprehensive security monitoring for the OIDC middleware. -// It tracks failures by IP address, detects suspicious patterns, enforces -// rate limits, and can trigger custom security event handlers. -type SecurityMonitor struct { - ipFailures map[string]*IPFailureTracker - patternDetector *SuspiciousPatternDetector - logger *Logger - cleanupTask *BackgroundTask - eventHandlers []SecurityEventHandler - config SecurityMonitorConfig - ipMutex sync.RWMutex -} - -// IPFailureTracker maintains failure statistics and blocking state for an IP address. -// Used for implementing progressive penalties and automatic IP blocking based on -// failure patterns, with support for different failure types for -// rate limiting and IP blocking decisions. -type IPFailureTracker struct { - // LastFailure timestamp of the most recent failure - LastFailure time.Time - // FirstFailure timestamp of the first failure in current window - FirstFailure time.Time - // BlockedUntil indicates when the IP block expires - BlockedUntil time.Time - // FailureTypes tracks counts by failure type - FailureTypes map[string]int64 - // FailureCount total number of failures - FailureCount int64 - // mutex protects concurrent access to tracker data - mutex sync.RWMutex - // IsBlocked indicates if this IP is currently blocked - IsBlocked bool -} - -// SuspiciousPatternDetector identifies attack patterns that may indicate coordinated threats. -// Analyzes events across multiple time windows to detect rapid failures, distributed attacks, -// and persistent attack patterns that individual IP monitoring might miss. -type SuspiciousPatternDetector struct { - // recentEvents stores recent security events for analysis - recentEvents []SecurityEvent - // shortWindow defines time frame for rapid failure detection - shortWindow time.Duration - // mediumWindow defines time frame for distributed attack detection - mediumWindow time.Duration - // longWindow defines time frame for persistent attack detection - longWindow time.Duration - // rapidFailureThreshold triggers rapid failure alerts - rapidFailureThreshold int - // distributedAttackThreshold triggers distributed attack alerts - distributedAttackThreshold int - // persistentAttackThreshold triggers persistent attack alerts - persistentAttackThreshold int - // eventsMutex protects concurrent access to events - eventsMutex sync.RWMutex -} - -// SecurityEventHandler defines the interface for processing security events. -// Implementations can log events, send alerts, update external systems, -// or trigger automated response actions. -type SecurityEventHandler interface { - // HandleSecurityEvent processes a security event - HandleSecurityEvent(event SecurityEvent) -} - -// SecurityMonitorConfig contains configuration parameters for the security monitor. -// Controls thresholds, time windows, and behavior for security monitoring. -type SecurityMonitorConfig struct { - // MaxFailuresPerIP sets the failure threshold before blocking - MaxFailuresPerIP int `json:"max_failures_per_ip"` - // FailureWindowMinutes defines the time window for counting failures - FailureWindowMinutes int `json:"failure_window_minutes"` - // BlockDurationMinutes sets how long to block an IP - BlockDurationMinutes int `json:"block_duration_minutes"` - // RapidFailureThreshold triggers rapid failure detection - RapidFailureThreshold int `json:"rapid_failure_threshold"` - // CleanupIntervalMinutes sets cleanup frequency for old data - CleanupIntervalMinutes int `json:"cleanup_interval_minutes"` - RetentionHours int `json:"retention_hours"` - EnablePatternDetection bool `json:"enable_pattern_detection"` - EnableDetailedLogging bool `json:"enable_detailed_logging"` - LogSuspiciousOnly bool `json:"log_suspicious_only"` -} - -// DefaultSecurityMonitorConfig returns a default configuration -func DefaultSecurityMonitorConfig() SecurityMonitorConfig { - return SecurityMonitorConfig{ - MaxFailuresPerIP: 10, - FailureWindowMinutes: 15, - BlockDurationMinutes: 60, - EnablePatternDetection: true, - RapidFailureThreshold: 5, - EnableDetailedLogging: true, - LogSuspiciousOnly: false, - CleanupIntervalMinutes: 30, - RetentionHours: 24, - } -} - -// NewSecurityMonitor creates a new security monitor instance -func NewSecurityMonitor(config SecurityMonitorConfig, logger *Logger) *SecurityMonitor { - sm := &SecurityMonitor{ - ipFailures: make(map[string]*IPFailureTracker), - eventHandlers: make([]SecurityEventHandler, 0), - config: config, - logger: logger, - patternDetector: NewSuspiciousPatternDetector(), - } - - sm.startCleanupRoutine() - - return sm -} - -// NewSuspiciousPatternDetector creates a new pattern detector -func NewSuspiciousPatternDetector() *SuspiciousPatternDetector { - return &SuspiciousPatternDetector{ - shortWindow: 1 * time.Minute, - mediumWindow: 5 * time.Minute, - longWindow: 15 * time.Minute, - rapidFailureThreshold: 5, - distributedAttackThreshold: 20, - persistentAttackThreshold: 50, - recentEvents: make([]SecurityEvent, 0), - } -} - -// RecordSecurityEvent is a generic method to record any type of security event -func (sm *SecurityMonitor) RecordSecurityEvent( - eventType SecurityEventType, - clientIP, userAgent, requestPath string, - message string, - details map[string]interface{}, - trackIPFailure bool) { - - event := SecurityEvent{ - Type: string(eventType), - Severity: eventType.DefaultSeverity(), - Timestamp: time.Now(), - ClientIP: clientIP, - UserAgent: userAgent, - RequestPath: requestPath, - Message: message, - Details: details, - } - - if trackIPFailure { - sm.recordIPFailure(clientIP, eventType.IPFailureType()) - } - - sm.processSecurityEvent(event) -} - -// RecordAuthenticationFailure records an authentication failure event -func (sm *SecurityMonitor) RecordAuthenticationFailure(clientIP, userAgent, requestPath, reason string, details map[string]interface{}) { - if details == nil { - details = make(map[string]interface{}) - } - details["reason"] = reason - - sm.RecordSecurityEvent( - AuthFailure, - clientIP, - userAgent, - requestPath, - fmt.Sprintf("Authentication failed: %s", reason), - details, - true, - ) -} - -// RecordTokenValidationFailure records a token validation failure -func (sm *SecurityMonitor) RecordTokenValidationFailure(clientIP, userAgent, requestPath, reason string, tokenPrefix string) { - details := map[string]interface{}{ - "reason": reason, - } - if tokenPrefix != "" { - details["token_prefix"] = tokenPrefix - } - - sm.RecordSecurityEvent( - TokenValidFailure, - clientIP, - userAgent, - requestPath, - fmt.Sprintf("Token validation failed: %s", reason), - details, - true, - ) -} - -// RecordRateLimitHit records when rate limiting is triggered -func (sm *SecurityMonitor) RecordRateLimitHit(clientIP, userAgent, requestPath string) { - details := map[string]interface{}{ - "limit_type": "token_verification", - } - - sm.RecordSecurityEvent( - RateLimitHit, - clientIP, - userAgent, - requestPath, - "Rate limit exceeded", - details, - true, - ) -} - -// RecordSuspiciousActivity records suspicious activity that doesn't fit other categories -func (sm *SecurityMonitor) RecordSuspiciousActivity(clientIP, userAgent, requestPath, activityType, description string, details map[string]interface{}) { - if details == nil { - details = make(map[string]interface{}) - } - details["activity_type"] = activityType - - sm.RecordSecurityEvent( - SuspiciousActivity, - clientIP, - userAgent, - requestPath, - fmt.Sprintf("Suspicious activity detected: %s - %s", activityType, description), - details, - true, - ) -} - -// recordIPFailure tracks failures for a specific IP address -func (sm *SecurityMonitor) recordIPFailure(clientIP, failureType string) { - sm.ipMutex.Lock() - defer sm.ipMutex.Unlock() - - tracker, exists := sm.ipFailures[clientIP] - if !exists { - tracker = &IPFailureTracker{ - FailureTypes: make(map[string]int64), - FirstFailure: time.Now(), - } - sm.ipFailures[clientIP] = tracker - } - - tracker.mutex.Lock() - defer tracker.mutex.Unlock() - - tracker.FailureCount++ - tracker.LastFailure = time.Now() - tracker.FailureTypes[failureType]++ - - windowStart := time.Now().Add(-time.Duration(sm.config.FailureWindowMinutes) * time.Minute) - if tracker.FirstFailure.After(windowStart) && tracker.FailureCount >= int64(sm.config.MaxFailuresPerIP) { - if !tracker.IsBlocked { - tracker.IsBlocked = true - tracker.BlockedUntil = time.Now().Add(time.Duration(sm.config.BlockDurationMinutes) * time.Minute) - - sm.logger.Errorf("IP %s blocked due to %d failures (types: %v)", clientIP, tracker.FailureCount, tracker.FailureTypes) - - blockEvent := SecurityEvent{ - Type: "ip_blocked", - Severity: "high", - Timestamp: time.Now(), - ClientIP: clientIP, - Message: fmt.Sprintf("IP blocked due to %d failures in %d minutes", tracker.FailureCount, sm.config.FailureWindowMinutes), - Details: map[string]interface{}{ - "failure_count": tracker.FailureCount, - "failure_types": tracker.FailureTypes, - "blocked_until": tracker.BlockedUntil, - }, - } - sm.processSecurityEvent(blockEvent) - } - } -} - -// IsIPBlocked checks if an IP address is currently blocked -func (sm *SecurityMonitor) IsIPBlocked(clientIP string) bool { - sm.ipMutex.RLock() - defer sm.ipMutex.RUnlock() - - tracker, exists := sm.ipFailures[clientIP] - if !exists { - return false - } - - tracker.mutex.RLock() - defer tracker.mutex.RUnlock() - - if tracker.IsBlocked && time.Now().Before(tracker.BlockedUntil) { - return true - } - - if tracker.IsBlocked && time.Now().After(tracker.BlockedUntil) { - tracker.IsBlocked = false - sm.logger.Infof("IP %s automatically unblocked", clientIP) - } - - return false -} - -// processSecurityEvent processes a security event through all handlers and pattern detection -func (sm *SecurityMonitor) processSecurityEvent(event SecurityEvent) { - if sm.config.EnablePatternDetection { - sm.patternDetector.AddEvent(event) - - if patterns := sm.patternDetector.DetectSuspiciousPatterns(); len(patterns) > 0 { - if len(patterns) == 1 { - sm.logger.Errorf("Suspicious pattern detected: %s", patterns[0]) - } else { - sm.logger.Errorf("Multiple suspicious patterns detected: %v", patterns) - } - - for _, pattern := range patterns { - patternEvent := SecurityEvent{ - Type: "suspicious_pattern", - Severity: "high", - Timestamp: time.Now(), - Message: fmt.Sprintf("Suspicious pattern detected: %s", pattern), - Details: map[string]interface{}{ - "pattern_type": pattern, - "trigger_event": event, - }, - } - sm.handleSecurityEvent(patternEvent) - } - } - } - - sm.handleSecurityEvent(event) -} - -// handleSecurityEvent sends the event to all registered handlers -func (sm *SecurityMonitor) handleSecurityEvent(event SecurityEvent) { - if sm.config.EnableDetailedLogging && (!sm.config.LogSuspiciousOnly || event.Severity == "high") { - sm.logger.Infof("Security Event [%s/%s]: %s (IP: %s, Path: %s)", - event.Type, event.Severity, event.Message, event.ClientIP, event.RequestPath) - } - - for _, handler := range sm.eventHandlers { - go handler.HandleSecurityEvent(event) - } -} - -// AddEventHandler adds a security event handler -func (sm *SecurityMonitor) AddEventHandler(handler SecurityEventHandler) { - sm.eventHandlers = append(sm.eventHandlers, handler) -} - -// This is kept for API compatibility but doesn't collect actual metrics -func (sm *SecurityMonitor) GetSecurityMetrics() map[string]interface{} { - return map[string]interface{}{ - "tracked_ips": 0, - } -} - -// AddEvent adds an event to the pattern detector -func (spd *SuspiciousPatternDetector) AddEvent(event SecurityEvent) { - spd.eventsMutex.Lock() - defer spd.eventsMutex.Unlock() - - spd.recentEvents = append(spd.recentEvents, event) - - cutoff := time.Now().Add(-spd.longWindow) - var filteredEvents []SecurityEvent - for _, e := range spd.recentEvents { - if e.Timestamp.After(cutoff) { - filteredEvents = append(filteredEvents, e) - } - } - spd.recentEvents = filteredEvents -} - -// DetectSuspiciousPatterns analyzes recent events for suspicious patterns -func (spd *SuspiciousPatternDetector) DetectSuspiciousPatterns() []string { - spd.eventsMutex.RLock() - defer spd.eventsMutex.RUnlock() - - var patterns []string - now := time.Now() - - ipCounts := make(map[string]int) - shortWindowStart := now.Add(-spd.shortWindow) - - for _, event := range spd.recentEvents { - if event.Timestamp.After(shortWindowStart) && - (event.Type == "authentication_failure" || event.Type == "token_validation_failure") { - ipCounts[event.ClientIP]++ - } - } - - for ip, count := range ipCounts { - if count >= spd.rapidFailureThreshold { - patterns = append(patterns, fmt.Sprintf("rapid_failures_from_ip_%s", ip)) - } - } - - mediumWindowStart := now.Add(-spd.mediumWindow) - uniqueFailingIPs := make(map[string]bool) - - for _, event := range spd.recentEvents { - if event.Timestamp.After(mediumWindowStart) && - (event.Type == "authentication_failure" || event.Type == "token_validation_failure") { - uniqueFailingIPs[event.ClientIP] = true - } - } - - if len(uniqueFailingIPs) >= spd.distributedAttackThreshold { - patterns = append(patterns, "distributed_attack_pattern") - } - - longWindowStart := now.Add(-spd.longWindow) - persistentFailures := 0 - - for _, event := range spd.recentEvents { - if event.Timestamp.After(longWindowStart) && - (event.Type == "authentication_failure" || event.Type == "token_validation_failure") { - persistentFailures++ - } - } - - if persistentFailures >= spd.persistentAttackThreshold { - patterns = append(patterns, "persistent_attack_pattern") - } - - return patterns -} - -// startCleanupRoutine starts the background cleanup routine -func (sm *SecurityMonitor) startCleanupRoutine() { - sm.cleanupTask = NewBackgroundTask( - "security-monitor-cleanup", - time.Duration(sm.config.CleanupIntervalMinutes)*time.Minute, - sm.cleanup, - sm.logger) - sm.cleanupTask.Start() -} - -// StopCleanupRoutine stops the background cleanup routine -func (sm *SecurityMonitor) StopCleanupRoutine() { - if sm.cleanupTask != nil { - sm.cleanupTask.Stop() - sm.cleanupTask = nil - } -} - -// cleanup removes old tracking data -func (sm *SecurityMonitor) cleanup() { - sm.ipMutex.Lock() - defer sm.ipMutex.Unlock() - - cutoff := time.Now().Add(-time.Duration(sm.config.RetentionHours) * time.Hour) - - for ip, tracker := range sm.ipFailures { - tracker.mutex.RLock() - shouldRemove := tracker.LastFailure.Before(cutoff) && !tracker.IsBlocked - tracker.mutex.RUnlock() - - if shouldRemove { - delete(sm.ipFailures, ip) - } - } - - sm.logger.Debugf("Security monitor cleanup completed, tracking %d IPs", len(sm.ipFailures)) -} - -// ExtractClientIP extracts the client IP from the request, considering proxy headers -func ExtractClientIP(r *http.Request) string { - if xri := r.Header.Get("X-Real-IP"); xri != "" { - if net.ParseIP(xri) != nil { - return xri - } - } - - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - ips := strings.Split(xff, ",") - if len(ips) > 0 { - ip := strings.TrimSpace(ips[0]) - if net.ParseIP(ip) != nil { - return ip - } - } - } - - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return host -} - -// LoggingSecurityEventHandler logs security events to the standard logger -type LoggingSecurityEventHandler struct { - logger *Logger -} - -// NewLoggingSecurityEventHandler creates a new logging event handler -func NewLoggingSecurityEventHandler(logger *Logger) *LoggingSecurityEventHandler { - return &LoggingSecurityEventHandler{logger: logger} -} - -// HandleSecurityEvent implements SecurityEventHandler -func (h *LoggingSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) { - switch event.Severity { - case "high": - h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP) - case "medium": - h.logger.Errorf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP) - case "low": - h.logger.Infof("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP) - default: - h.logger.Debugf("SECURITY [%s]: %s (IP: %s)", event.Type, event.Message, event.ClientIP) - } -} diff --git a/security_monitoring_test.go b/security_monitoring_test.go deleted file mode 100644 index 3179657..0000000 --- a/security_monitoring_test.go +++ /dev/null @@ -1,285 +0,0 @@ -package traefikoidc - -import ( - "net/http/httptest" - "strconv" - "testing" - "time" -) - -func TestSecurityMonitor(t *testing.T) { - config := DefaultSecurityMonitorConfig() - config.MaxFailuresPerIP = 3 - config.BlockDurationMinutes = 1 // 1 minute for testing - config.CleanupIntervalMinutes = 1 - - logger := NewLogger("debug") - monitor := NewSecurityMonitor(config, logger) - defer func() { - // Allow cleanup goroutine to finish - time.Sleep(150 * time.Millisecond) - }() - - t.Run("Record authentication failure", func(t *testing.T) { - monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "invalid credentials", nil) - - // Should not be blocked after first failure - if monitor.IsIPBlocked("192.168.1.1") { - t.Error("IP should not be blocked after first failure") - } - }) - - t.Run("IP blocked after max failures", func(t *testing.T) { - // Record multiple failures - for i := 0; i < config.MaxFailuresPerIP; i++ { - monitor.RecordAuthenticationFailure("192.168.1.2", "test-agent", "/login", "invalid credentials", nil) - } - - // Should be blocked now - if !monitor.IsIPBlocked("192.168.1.2") { - t.Error("IP should be blocked after max failures") - } - }) - - t.Run("Token validation failure", func(t *testing.T) { - // Just verify the method doesn't panic - monitor.RecordTokenValidationFailure("192.168.1.3", "test-agent", "/api", "invalid token", "abc123") - }) - - t.Run("Rate limit hit", func(t *testing.T) { - // Just verify the method doesn't panic - monitor.RecordRateLimitHit("192.168.1.4", "test-agent", "/api") - }) - - t.Run("Suspicious activity", func(t *testing.T) { - details := map[string]interface{}{"pattern": "unusual"} - // Just verify the method doesn't panic - monitor.RecordSuspiciousActivity("192.168.1.5", "test-agent", "/admin", "unusual pattern", "high frequency requests", details) - }) -} - -func TestSuspiciousPatternDetector(t *testing.T) { - detector := NewSuspiciousPatternDetector() - - t.Run("Add events and detect patterns", func(t *testing.T) { - // Add multiple events from same IP - for i := 0; i < 10; i++ { - event := SecurityEvent{ - Type: "authentication_failure", - ClientIP: "192.168.1.100", - Timestamp: time.Now(), - } - detector.AddEvent(event) - } - - patterns := detector.DetectSuspiciousPatterns() - - found := false - for _, p := range patterns { - if p == "rapid_failures_from_ip_192.168.1.100" { - found = true - break - } - } - if !found { - t.Error("Expected to detect rapid failure pattern") - } - }) - - t.Run("Detect distributed attack pattern", func(t *testing.T) { - // Add failures from many different IPs - for i := 0; i < 25; i++ { - event := SecurityEvent{ - Type: "authentication_failure", - ClientIP: "192.168.1." + strconv.Itoa(100+i), - Timestamp: time.Now(), - } - detector.AddEvent(event) - } - - patterns := detector.DetectSuspiciousPatterns() - - found := false - for _, p := range patterns { - if p == "distributed_attack_pattern" { - found = true - break - } - } - if !found { - t.Error("Expected to detect distributed attack pattern") - } - }) -} - -func TestExtractClientIP(t *testing.T) { - tests := []struct { - name string - remoteAddr string - headers map[string]string - expectedIP string - }{ - { - name: "Direct connection", - remoteAddr: "192.168.1.1:12345", - expectedIP: "192.168.1.1", - }, - { - name: "X-Forwarded-For header", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{"X-Forwarded-For": "203.0.113.1, 10.0.0.1"}, - expectedIP: "203.0.113.1", - }, - { - name: "X-Real-IP header", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{"X-Real-IP": "203.0.113.2"}, - expectedIP: "203.0.113.2", - }, - { - name: "Multiple headers - X-Real-IP takes precedence", - remoteAddr: "10.0.0.1:12345", - headers: map[string]string{ - "X-Forwarded-For": "203.0.113.1", - "X-Real-IP": "203.0.113.2", - }, - expectedIP: "203.0.113.2", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) - req.RemoteAddr = tt.remoteAddr - - for key, value := range tt.headers { - req.Header.Set(key, value) - } - - ip := ExtractClientIP(req) - if ip != tt.expectedIP { - t.Errorf("Expected IP %s, got %s", tt.expectedIP, ip) - } - }) - } -} - -func TestSecurityEventHandlers(t *testing.T) { - t.Run("Logging security event handler", func(t *testing.T) { - logger := NewLogger("debug") - handler := NewLoggingSecurityEventHandler(logger) - - event := SecurityEvent{ - Type: "authentication_failure", - ClientIP: "192.168.1.1", - Timestamp: time.Now(), - Message: "Test failure", - Severity: "medium", - } - - // Should not panic - handler.HandleSecurityEvent(event) - }) - - // Metrics security event handler test removed as part of metrics cleanup -} - -func TestSecurityMonitorEventHandlers(t *testing.T) { - config := DefaultSecurityMonitorConfig() - logger := NewLogger("debug") - monitor := NewSecurityMonitor(config, logger) - - // Add event handler with proper synchronization - handlerCalled := make(chan bool, 1) - handler := &testSecurityEventHandler{ - callback: func(event SecurityEvent) { - select { - case handlerCalled <- true: - default: - // Channel already has a value, don't block - } - }, - } - monitor.AddEventHandler(handler) - - monitor.RecordAuthenticationFailure("192.168.1.1", "test-agent", "/login", "test failure", nil) - - // Wait for event handler to be called with timeout - select { - case <-handlerCalled: - // Success - handler was called - case <-time.After(100 * time.Millisecond): - t.Error("Expected event handler to be called within timeout") - } -} - -// Test helper for security event handler -type testSecurityEventHandler struct { - callback func(SecurityEvent) -} - -func (h *testSecurityEventHandler) HandleSecurityEvent(event SecurityEvent) { - h.callback(event) -} - -func TestDefaultSecurityMonitorConfig(t *testing.T) { - config := DefaultSecurityMonitorConfig() - - if config.MaxFailuresPerIP <= 0 { - t.Error("Expected positive MaxFailuresPerIP") - } - if config.BlockDurationMinutes <= 0 { - t.Error("Expected positive BlockDurationMinutes") - } - if config.CleanupIntervalMinutes <= 0 { - t.Error("Expected positive CleanupIntervalMinutes") - } - if config.FailureWindowMinutes <= 0 { - t.Error("Expected positive FailureWindowMinutes") - } -} - -func TestSecurityMonitorCleanup(t *testing.T) { - config := DefaultSecurityMonitorConfig() - config.CleanupIntervalMinutes = 1 - config.BlockDurationMinutes = 1 - config.RetentionHours = 1 - - logger := NewLogger("debug") - monitor := NewSecurityMonitor(config, logger) - - // Block an IP - for i := 0; i < config.MaxFailuresPerIP; i++ { - monitor.RecordAuthenticationFailure("192.168.1.99", "test-agent", "/login", "test", nil) - } - - // Verify it's blocked - if !monitor.IsIPBlocked("192.168.1.99") { - t.Error("IP should be blocked") - } - - // Wait a bit and check if it gets unblocked automatically - time.Sleep(100 * time.Millisecond) - - // The IP should still be blocked since we haven't waited long enough - if !monitor.IsIPBlocked("192.168.1.99") { - t.Error("IP should still be blocked") - } -} - -func TestSecurityEventTypes(t *testing.T) { - config := DefaultSecurityMonitorConfig() - logger := NewLogger("debug") - monitor := NewSecurityMonitor(config, logger) - - // Test different event types - just verify they don't panic - monitor.RecordAuthenticationFailure("192.168.1.200", "test-agent", "/login", "invalid password", nil) - monitor.RecordTokenValidationFailure("192.168.1.200", "test-agent", "/api", "expired token", "abc123") - monitor.RecordRateLimitHit("192.168.1.200", "test-agent", "/api") - - details := map[string]interface{}{"pattern": "test"} - monitor.RecordSuspiciousActivity("192.168.1.200", "test-agent", "/admin", "unusual pattern", "multiple failed logins", details) - - // Just verify GetSecurityMetrics doesn't panic - _ = monitor.GetSecurityMetrics() -} diff --git a/session.go b/session.go index d282d1f..f34dcec 100644 --- a/session.go +++ b/session.go @@ -4,7 +4,9 @@ import ( "bytes" "compress/gzip" "context" + "crypto/hmac" "crypto/rand" + "crypto/sha256" "crypto/subtle" "encoding/base64" "encoding/hex" @@ -31,6 +33,45 @@ func constantTimeStringCompare(a, b string) bool { return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 } +// deriveCookieKeys derives an independent 64-byte HMAC authentication key and a +// 32-byte AES-256 encryption key from the operator-provided session encryption +// key using HKDF-SHA256 (RFC 5869). +// +// gorilla/securecookie only ENCRYPTS the cookie payload when a block +// (encryption) key is supplied; constructing the store with a single key leaves +// sessions signed-but-plaintext, so the OIDC access/refresh/ID tokens stored in +// the cookie are recoverable by anyone who can read the raw cookie bytes. Two +// independent keys are derived here so the cookie is both encrypted and +// authenticated. HKDF is implemented with stdlib hmac+sha256 so it runs under +// Traefik's yaegi interpreter, which may not export crypto/hkdf. +func deriveCookieKeys(secret string) (authKey, encKey []byte) { + okm := hkdfSHA256([]byte(secret), nil, []byte("traefikoidc session cookie keys v1"), 96) + return okm[:64], okm[64:96] +} + +// hkdfSHA256 performs HKDF-Extract followed by HKDF-Expand (RFC 5869) using +// HMAC-SHA256 and returns length bytes of output keying material. +func hkdfSHA256(ikm, salt, info []byte, length int) []byte { + if len(salt) == 0 { + salt = make([]byte, sha256.Size) + } + // Extract: PRK = HMAC-SHA256(salt, IKM) + ext := hmac.New(sha256.New, salt) + ext.Write(ikm) + prk := ext.Sum(nil) + // Expand: T(i) = HMAC-SHA256(PRK, T(i-1) | info | i) + var out, t []byte + for i := byte(1); len(out) < length; i++ { + exp := hmac.New(sha256.New, prk) + exp.Write(t) + exp.Write(info) + exp.Write([]byte{i}) + t = exp.Sum(nil) + out = append(out, t...) + } + return out[:length] +} + // min returns the minimum of two integers. // This is a utility function used throughout the session management code. // Parameters: @@ -118,12 +159,12 @@ var knownSessionKeys = map[string]bool{ "id_token": true, "user_identifier": true, "authenticated": true, - "csrf": true, - "nonce": true, - "code_verifier": true, - "incoming_path": true, - "created_at": true, - "redirect_count": true, + "csrf": true, + "nonce": true, + "code_verifier": true, + "incoming_path": true, + "created_at": true, + "redirect_count": true, } // compressCombinedPayload compresses the combined session payload using gzip. @@ -423,8 +464,13 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin ctx, cancel := context.WithCancel(context.Background()) + // Derive independent authentication + encryption keys so the session cookie + // is AES-256 encrypted and HMAC authenticated, not merely signed. See + // deriveCookieKeys: a single key would leave the stored tokens in plaintext. + authKey, encKey := deriveCookieKeys(encryptionKey) + sm := &SessionManager{ - store: sessions.NewCookieStore([]byte(encryptionKey)), + store: sessions.NewCookieStore(authKey, encKey), forceHTTPS: forceHTTPS, cookieDomain: cookieDomain, cookiePrefix: cookiePrefix, @@ -435,6 +481,14 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin cancel: cancel, } + // Bind the cookie codec's timestamp validity (and the cookie Max-Age) to the + // configured session lifetime instead of gorilla's 30-day default, so a + // stolen cookie is not cryptographically valid for up to 30 days regardless + // of the (possibly much shorter) configured sessionMaxAge (rank 13). + if cs, ok := sm.store.(*sessions.CookieStore); ok { + cs.MaxAge(int(sessionMaxAge.Seconds())) + } + // Initialize global memory monitoring (singleton) sm.memoryMonitor = GetGlobalTaskMemoryMonitor(logger) @@ -1566,7 +1620,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { sd.sessionMutex.Lock() sd.clearAllSessionData(r, true) - + // Release the lock before calling Save to prevent deadlock sd.sessionMutex.Unlock() diff --git a/settings.go b/settings.go index 4f0d70f..dfccd51 100644 --- a/settings.go +++ b/settings.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/url" "os" @@ -762,7 +763,32 @@ func validateTemplateSecure(templateStr string) error { // Returns true if the URL is valid and secure (HTTPS), false otherwise. func isValidSecureURL(s string) bool { u, err := url.Parse(s) - return err == nil && u.Scheme == "https" && u.Host != "" + if err != nil || u.Host == "" { + return false + } + if u.Scheme == "https" { + return true + } + // Permit plaintext HTTP only for loopback hosts (local development, + // in-cluster sidecar providers, tests). Loopback traffic never leaves the + // host, so it is not exposed to network MITM; remote providers must use + // HTTPS. Mirrors the RFC 8252 loopback allowance. + if u.Scheme == "http" && isLoopbackHost(u.Hostname()) { + return true + } + return false +} + +// isLoopbackHost reports whether host is "localhost" or a loopback IP literal +// (127.0.0.0/8 or ::1). +func isLoopbackHost(host string) bool { + if strings.EqualFold(host, "localhost") { + return true + } + if ip := net.ParseIP(host); ip != nil { + return ip.IsLoopback() + } + return false } // isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error"). diff --git a/singleton_resources.go b/singleton_resources.go index e7e69f4..a52f12e 100644 --- a/singleton_resources.go +++ b/singleton_resources.go @@ -106,8 +106,9 @@ func (rm *ResourceManager) GetCache(key string) interface{} { case "jwk-cache": cache = cacheManager.GetSharedJWKCache() default: - // Generic cache implementation - cache = NewGenericCache(1*time.Hour, rm.logger) + // Generic cache implementation; bind cleanup goroutine to the manager's + // shutdown channel so it exits when the ResourceManager shuts down. + cache = newGenericCacheWithOwner(1*time.Hour, rm.logger, rm.shutdownChan) } rm.caches[key] = cache @@ -263,6 +264,19 @@ func (rm *ResourceManager) cleanupInstance(instanceID string) { // This is a hook for future instance-specific cleanup needs } +// liveInstanceCount tracks the number of fully-constructed TraefikOidc plugin +// instances alive in this process. Process-global singleton tasks (such as the +// shared token-cleanup) must only be stopped when the LAST instance shuts down, +// otherwise one instance's teardown would disable them for all survivors. +var liveInstanceCount int32 + +// registerLiveInstance records a newly constructed plugin instance. +func registerLiveInstance() { atomic.AddInt32(&liveInstanceCount, 1) } + +// unregisterLiveInstance records a plugin instance shutting down and returns the +// number of instances still alive afterwards. +func unregisterLiveInstance() int32 { return atomic.AddInt32(&liveInstanceCount, -1) } + // Shutdown gracefully shuts down all managed resources func (rm *ResourceManager) Shutdown(ctx context.Context) error { var err error @@ -501,20 +515,31 @@ func (p *GoroutinePool) Shutdown(ctx context.Context) error { // GenericCache provides a simple cache implementation for testing type GenericCache struct { - data map[string]interface{} - logger *Logger - stopChan chan struct{} - ttl time.Duration - mu sync.RWMutex + data map[string]interface{} + // ownerStopChan, when non-nil, signals the cleanup goroutine to exit when + // the owning ResourceManager shuts down, so the goroutine cannot outlive it. + ownerStopChan <-chan struct{} + logger *Logger + stopChan chan struct{} + ttl time.Duration + mu sync.RWMutex } // NewGenericCache creates a new generic cache func NewGenericCache(ttl time.Duration, logger *Logger) *GenericCache { + return newGenericCacheWithOwner(ttl, logger, nil) +} + +// newGenericCacheWithOwner creates a generic cache whose cleanup goroutine also +// exits when ownerStopChan is closed (typically the ResourceManager shutdown +// channel), guaranteeing the goroutine is stoppable on shutdown. +func newGenericCacheWithOwner(ttl time.Duration, logger *Logger, ownerStopChan <-chan struct{}) *GenericCache { cache := &GenericCache{ - data: make(map[string]interface{}), - ttl: ttl, - logger: logger, - stopChan: make(chan struct{}), + data: make(map[string]interface{}), + ttl: ttl, + logger: logger, + stopChan: make(chan struct{}), + ownerStopChan: ownerStopChan, } // Start cleanup routine @@ -570,6 +595,11 @@ func (gc *GenericCache) cleanupRoutine() { gc.mu.Unlock() case <-gc.stopChan: return + case <-gc.ownerStopChan: + // Owning ResourceManager is shutting down; exit so the goroutine + // does not outlive its owner. A nil channel blocks forever, so this + // case is inert when no owner is set. + return } } } diff --git a/singleton_resources_test.go b/singleton_resources_test.go index d6634b8..faa5012 100644 --- a/singleton_resources_test.go +++ b/singleton_resources_test.go @@ -296,9 +296,12 @@ func TestContextAwareGoroutineManagement(t *testing.T) { // Create a TraefikOidc instance with context config := &Config{ - ProviderURL: mockServer.URL, - ClientID: "test-client", - ClientSecret: "test-secret", + ProviderURL: mockServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-32-bytes-long", + RateLimit: 100, } plugin, err := NewWithContext(ctx, config, nil, "test") @@ -350,9 +353,9 @@ func TestContextAwareGoroutineManagement(t *testing.T) { initialGoroutines := runtime.NumGoroutine() configs := []Config{ - {ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"}, - {ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"}, - {ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"}, + {ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100}, + {ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100}, + {ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-32-bytes-long", RateLimit: 100}, } var plugins []*TraefikOidc @@ -432,9 +435,12 @@ func TestContextAwareGoroutineManagement(t *testing.T) { for i := 0; i < 3; i++ { ctx := context.Background() config := &Config{ - ProviderURL: mockServers[i].URL, - ClientID: fmt.Sprintf("client%d", i), - ClientSecret: fmt.Sprintf("secret%d", i), + ProviderURL: mockServers[i].URL, + ClientID: fmt.Sprintf("client%d", i), + ClientSecret: fmt.Sprintf("secret%d", i), + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-32-bytes-long", + RateLimit: 100, } plugin, err := NewWithContext(ctx, config, nil, fmt.Sprintf("test-%d", i)) @@ -595,9 +601,12 @@ func TestBackwardCompatibility(t *testing.T) { t.Run("LegacyNewFunction", func(t *testing.T) { // Test that the original New function still works config := &Config{ - ProviderURL: "https://example.com", - ClientID: "test-client", - ClientSecret: "test-secret", + ProviderURL: "https://example.com", + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-32-bytes-long", + RateLimit: 100, } handler, err := New(context.Background(), nil, config, "test") @@ -617,9 +626,12 @@ func TestBackwardCompatibility(t *testing.T) { t.Run("ExistingAPICompatibility", func(t *testing.T) { config := &Config{ - ProviderURL: "https://example.com", - ClientID: "test-client", - ClientSecret: "test-secret", + ProviderURL: "https://example.com", + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-32-bytes-long", + RateLimit: 100, } handler, _ := New(context.Background(), nil, config, "test") diff --git a/token_introspection.go b/token_introspection.go index 3c05102..1ceb03e 100644 --- a/token_introspection.go +++ b/token_introspection.go @@ -21,13 +21,16 @@ type IntrospectionResponse struct { Username string `json:"username,omitempty"` TokenType string `json:"token_type,omitempty"` Sub string `json:"sub,omitempty"` - Aud string `json:"aud,omitempty"` - Iss string `json:"iss,omitempty"` - Jti string `json:"jti,omitempty"` - Exp int64 `json:"exp,omitempty"` - Iat int64 `json:"iat,omitempty"` - Nbf int64 `json:"nbf,omitempty"` - Active bool `json:"active"` + // Aud holds the introspection audience. Per RFC 7662 it may be a single + // string or an array of strings, so it is decoded as interface{} and + // matched with verifyAudience (which handles both shapes). + Aud interface{} `json:"aud,omitempty"` + Iss string `json:"iss,omitempty"` + Jti string `json:"jti,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` + Nbf int64 `json:"nbf,omitempty"` + Active bool `json:"active"` } // introspectToken performs OAuth 2.0 Token Introspection (RFC 7662) for an opaque token. @@ -120,7 +123,7 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err // Parse response per RFC 7662 Section 2.2 var introspectionResp IntrospectionResponse - if err := json.NewDecoder(resp.Body).Decode(&introspectionResp); err != nil { + if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&introspectionResp); err != nil { return nil, fmt.Errorf("failed to decode introspection response: %w", err) } @@ -128,6 +131,12 @@ func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, err if t.introspectionCache != nil { // Cache for a short duration or until token expiry (whichever is shorter) cacheDuration := 5 * time.Minute + // When introspection is REQUIRED, operators expect near-real-time + // revocation; cap the positive-result cache so a token revoked at the + // provider cannot keep passing for the full 5 minutes (rank 8). + if t.requireTokenIntrospection && cacheDuration > 30*time.Second { + cacheDuration = 30 * time.Second + } if introspectionResp.Exp > 0 { expTime := time.Unix(introspectionResp.Exp, 0) untilExp := time.Until(expTime) @@ -197,12 +206,18 @@ func (t *TraefikOidc) validateOpaqueToken(token string) error { } } - // Validate audience if configured - // Note: For opaque tokens, audience validation via introspection may be limited - // depending on what the introspection endpoint returns - if t.audience != "" && t.audience != t.clientID && resp.Aud != "" { - if resp.Aud != t.audience { - return fmt.Errorf("invalid audience: expected %s, got %s", t.audience, resp.Aud) + // Validate audience if configured. When a distinct API audience is + // configured (audience != clientID), the introspection response MUST carry + // a matching audience. Fail closed on a missing or mismatched aud: a token + // whose audience cannot be confirmed must not be accepted, otherwise a + // token minted for a different audience would pass. aud may be a single + // string or an array of strings (RFC 7662); verifyAudience handles both. + if t.audience != "" && t.audience != t.clientID { + if resp.Aud == nil { + return fmt.Errorf("invalid audience: expected %s, introspection response has no audience", t.audience) + } + if err := verifyAudience(resp.Aud, t.audience); err != nil { + return fmt.Errorf("invalid audience: expected %s: %w", t.audience, err) } } diff --git a/token_manager.go b/token_manager.go index bdd0b77..8ae46b0 100644 --- a/token_manager.go +++ b/token_manager.go @@ -5,6 +5,8 @@ package traefikoidc import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "io" "net/http" @@ -212,11 +214,13 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa // //nolint:gocognit,gocyclo // Complex token type detection with multiple provider-specific checks func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool { - // Use first 32 chars of token as cache key (sufficient for uniqueness) - cacheKey := token - if len(token) > 32 { - cacheKey = token[:32] - } + // Key on a hash of the FULL token. The first 32 characters of a JWT are + // only the base64url-encoded header, which is identical for every token + // sharing the same alg+kid, so distinct tokens (e.g. an ID token and an + // access token from the same issuer) would otherwise collide on the cache + // key and be mis-classified. + sum := sha256.Sum256([]byte(token)) + cacheKey := hex.EncodeToString(sum[:]) // Check cache first if t.tokenTypeCache != nil { @@ -858,7 +862,6 @@ func (t *TraefikOidc) isAzureProvider() bool { strings.Contains(issuerURL, "login.windows.net") } - // startTokenCleanup starts background cleanup goroutines for cache maintenance. // It runs periodic cleanup of token cache, JWK cache, and session chunks. // Includes panic recovery to ensure stability. diff --git a/token_test.go b/token_test.go index d4b4dce..059eeda 100644 --- a/token_test.go +++ b/token_test.go @@ -3,7 +3,9 @@ package traefikoidc import ( "bytes" "compress/gzip" + "crypto/sha256" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "io" @@ -885,10 +887,9 @@ func TestDetectTokenTypeCaching(t *testing.T) { }, } token := "test-token-for-caching-with-enough-characters-for-key" - cacheKey := token - if len(token) > 32 { - cacheKey = token[:32] - } + // The cache key is a SHA-256 hash of the full token (collision-resistant). + sum := sha256.Sum256([]byte(token)) + cacheKey := hex.EncodeToString(sum[:]) result := tr.detectTokenType(jwt, token) if !result { @@ -911,521 +912,6 @@ func TestDetectTokenTypeCaching(t *testing.T) { } } -// ============================================================================= -// TOKEN VALIDATOR TESTS -// ============================================================================= - -func TestNewTokenValidator(t *testing.T) { - validator := NewTokenValidator(nil) - - if validator == nil { - t.Fatal("Expected non-nil token validator") - } - - if validator.logger == nil { - t.Error("Expected logger to be initialized") - } -} - -func TestNewTokenValidatorWithLogger(t *testing.T) { - logger := GetSingletonNoOpLogger() - validator := NewTokenValidator(logger) - - if validator == nil { - t.Fatal("Expected non-nil token validator") - } - - if validator.logger != logger { - t.Error("Expected provided logger to be used") - } -} - -func TestValidateTokenEmpty(t *testing.T) { - validator := NewTokenValidator(nil) - result := validator.ValidateToken("", false) - - if result.Valid { - t.Error("Expected invalid result for empty token") - } - - if result.Error == nil { - t.Error("Expected error for empty token") - } - - if !strings.Contains(result.Error.Error(), "empty") { - t.Errorf("Expected 'empty' in error, got: %v", result.Error) - } -} - -func TestValidateTokenRequireJWT(t *testing.T) { - validator := NewTokenValidator(nil) - - result := validator.ValidateToken("opaque_token_value_here", true) - - if result.Valid { - t.Error("Expected invalid result for opaque token when JWT required") - } - - if result.Error == nil { - t.Error("Expected error when JWT required but opaque token provided") - } -} - -func TestValidateJWTValidFormat(t *testing.T) { - validator := NewTokenValidator(nil) - - claims := map[string]interface{}{ - "sub": "user123", - "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), - } - - token := createTestJWTSimple(claims) - result := validator.ValidateToken(token, false) - - if !result.Valid { - t.Errorf("Expected valid result, got error: %v", result.Error) - } - - if result.TokenType != "JWT" { - t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType) - } - - if result.Claims == nil { - t.Error("Expected claims to be parsed") - } - - if result.Expiry == nil { - t.Error("Expected expiry to be extracted") - } - - if result.IssuedAt == nil { - t.Error("Expected issued at to be extracted") - } -} - -func TestValidateJWTExpiredToken(t *testing.T) { - validator := NewTokenValidator(nil) - - claims := map[string]interface{}{ - "sub": "user123", - "exp": time.Now().Add(-1 * time.Hour).Unix(), - "iat": time.Now().Add(-2 * time.Hour).Unix(), - } - - token := createTestJWTSimple(claims) - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for expired token") - } - - if result.Error == nil { - t.Error("Expected error for expired token") - } - - if !strings.Contains(result.Error.Error(), "expired") { - t.Errorf("Expected 'expired' in error, got: %v", result.Error) - } -} - -func TestValidateJWTFutureIssuedAt(t *testing.T) { - validator := NewTokenValidator(nil) - - claims := map[string]interface{}{ - "sub": "user123", - "exp": time.Now().Add(2 * time.Hour).Unix(), - "iat": time.Now().Add(10 * time.Minute).Unix(), - } - - token := createTestJWTSimple(claims) - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for future iat") - } - - if result.Error == nil { - t.Error("Expected error for future iat") - } - - if !strings.Contains(result.Error.Error(), "future") { - t.Errorf("Expected 'future' in error, got: %v", result.Error) - } -} - -func TestValidateJWTNotBeforeClaim(t *testing.T) { - validator := NewTokenValidator(nil) - - claims := map[string]interface{}{ - "sub": "user123", - "exp": time.Now().Add(2 * time.Hour).Unix(), - "iat": time.Now().Unix(), - "nbf": time.Now().Add(1 * time.Hour).Unix(), - } - - token := createTestJWTSimple(claims) - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for nbf in future") - } - - if result.Error == nil { - t.Error("Expected error for nbf in future") - } - - if !strings.Contains(result.Error.Error(), "not yet valid") { - t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error) - } -} - -func TestValidateJWTInvalidFormat(t *testing.T) { - validator := NewTokenValidator(nil) - - tests := []struct { - name string - token string - }{ - {"single part", "eyJhbGciOiJIUzI1NiJ9"}, - {"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"}, - {"four parts", "part1.part2.part3.part4"}, - {"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := validator.ValidateToken(tt.token, true) - - if result.Valid { - t.Error("Expected invalid result for malformed JWT") - } - - if result.Error == nil { - t.Error("Expected error for malformed JWT") - } - }) - } -} - -func TestValidateOpaqueTokenValid(t *testing.T) { - validator := NewTokenValidator(nil) - - token := "sk_live_abcdef123456GHIJKL789" - result := validator.ValidateToken(token, false) - - if !result.Valid { - t.Errorf("Expected valid result, got error: %v", result.Error) - } - - if result.TokenType != "Opaque" { - t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType) - } -} - -func TestValidateOpaqueTokenTooShort(t *testing.T) { - validator := NewTokenValidator(nil) - - token := "short" - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for short token") - } - - if result.Error == nil { - t.Error("Expected error for short token") - } - - if !strings.Contains(result.Error.Error(), "too short") { - t.Errorf("Expected 'too short' in error, got: %v", result.Error) - } -} - -func TestValidateOpaqueTokenWithSpaces(t *testing.T) { - validator := NewTokenValidator(nil) - - token := "this token has spaces in it" - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for token with spaces") - } - - if result.Error == nil { - t.Error("Expected error for token with spaces") - } - - if !strings.Contains(result.Error.Error(), "spaces") { - t.Errorf("Expected 'spaces' in error, got: %v", result.Error) - } -} - -func TestValidateOpaqueTokenControlCharacters(t *testing.T) { - validator := NewTokenValidator(nil) - - token := "token_with\x00control_char" - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for token with control characters") - } - - if result.Error == nil { - t.Error("Expected error for token with control characters") - } - - if !strings.Contains(result.Error.Error(), "control character") { - t.Errorf("Expected 'control character' in error, got: %v", result.Error) - } -} - -func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) { - validator := NewTokenValidator(nil) - - token := "aaaaaabbbbbbccccccdddd" - result := validator.ValidateToken(token, false) - - if result.Valid { - t.Error("Expected invalid result for low entropy token") - } - - if result.Error == nil { - t.Error("Expected error for low entropy token") - } - - if !strings.Contains(result.Error.Error(), "entropy") { - t.Errorf("Expected 'entropy' in error, got: %v", result.Error) - } -} - -func TestIsValidBase64URL(t *testing.T) { - validator := NewTokenValidator(nil) - - tests := []struct { - name string - input string - expected bool - }{ - {"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true}, - {"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true}, - {"valid numbers", "0123456789", true}, - {"valid dash", "abc-def", true}, - {"valid underscore", "abc_def", true}, - {"valid equals", "abc=", true}, - {"invalid at sign", "abc@def", false}, - {"invalid space", "abc def", false}, - {"invalid plus", "abc+def", false}, - {"invalid slash", "abc/def", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := validator.isValidBase64URL(tt.input) - if result != tt.expected { - t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result) - } - }) - } -} - -func TestExtractTime(t *testing.T) { - validator := NewTokenValidator(nil) - - tests := []struct { - claim interface{} - name string - expected bool - }{ - {name: "float64", claim: float64(1609459200), expected: true}, - {name: "int64", claim: int64(1609459200), expected: true}, - {name: "int", claim: int(1609459200), expected: true}, - {name: "string", claim: "not a timestamp", expected: false}, - {name: "nil", claim: nil, expected: false}, - {name: "map", claim: map[string]interface{}{}, expected: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := validator.extractTime(tt.claim) - - if tt.expected && result == nil { - t.Error("Expected non-nil time") - } - - if !tt.expected && result != nil { - t.Error("Expected nil time") - } - }) - } -} - -func TestValidateTokenSize(t *testing.T) { - validator := NewTokenValidator(nil) - - tests := []struct { - name string - token string - maxSize int - expectError bool - }{ - {"within limit", "short_token", 20, false}, - {"at limit", "exactly_twenty_c", 16, false}, - {"exceeds limit", "this_token_is_too_long", 10, true}, - {"empty token", "", 10, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validator.ValidateTokenSize(tt.token, tt.maxSize) - - if tt.expectError && err == nil { - t.Error("Expected error for oversized token") - } - - if !tt.expectError && err != nil { - t.Errorf("Expected no error, got: %v", err) - } - - if err != nil && !strings.Contains(err.Error(), "exceeds") { - t.Errorf("Expected 'exceeds' in error, got: %v", err) - } - }) - } -} - -func TestExtractClaims(t *testing.T) { - validator := NewTokenValidator(nil) - - claims := map[string]interface{}{ - "sub": "user123", - "email": "user@example.com", - "exp": float64(1609459200), - } - - token := createTestJWTSimple(claims) - extracted, err := validator.ExtractClaims(token) - - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if extracted == nil { - t.Fatal("Expected non-nil claims") - } - - if extracted["sub"] != "user123" { - t.Errorf("Expected sub 'user123', got %v", extracted["sub"]) - } - - if extracted["email"] != "user@example.com" { - t.Errorf("Expected email 'user@example.com', got %v", extracted["email"]) - } -} - -func TestExtractClaimsInvalidFormat(t *testing.T) { - validator := NewTokenValidator(nil) - - tests := []struct { - name string - token string - }{ - {"single part", "onlyonepart"}, - {"two parts", "two.parts"}, - {"four parts", "one.two.three.four"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := validator.ExtractClaims(tt.token) - - if err == nil { - t.Error("Expected error for invalid format") - } - - if !strings.Contains(err.Error(), "invalid JWT format") { - t.Errorf("Expected 'invalid JWT format' in error, got: %v", err) - } - }) - } -} - -func TestCompareTokensEqual(t *testing.T) { - validator := NewTokenValidator(nil) - - token1 := "secret_token_12345" - token2 := "secret_token_12345" - - if !validator.CompareTokens(token1, token2) { - t.Error("Expected tokens to be equal") - } -} - -func TestCompareTokensDifferent(t *testing.T) { - validator := NewTokenValidator(nil) - - token1 := "secret_token_12345" - token2 := "secret_token_54321" - - if validator.CompareTokens(token1, token2) { - t.Error("Expected tokens to be different") - } -} - -func TestCompareTokensDifferentLength(t *testing.T) { - validator := NewTokenValidator(nil) - - token1 := "short" - token2 := "much_longer_token" - - if validator.CompareTokens(token1, token2) { - t.Error("Expected tokens to be different (different lengths)") - } -} - -func TestCompareTokensEmpty(t *testing.T) { - validator := NewTokenValidator(nil) - - token1 := "" - token2 := "" - - if !validator.CompareTokens(token1, token2) { - t.Error("Expected empty tokens to be equal") - } -} - -func TestValidateTokenMaliciousPayloads(t *testing.T) { - validator := NewTokenValidator(nil) - - tests := []struct { - name string - token string - }{ - {"sql injection attempt", "'; DROP TABLE users; --"}, - {"xss attempt", ""}, - {"path traversal", "../../../etc/passwd"}, - {"null bytes", "token\x00with\x00nulls"}, - {"unicode exploit", "token\u0000\u0001\u0002"}, - {"extremely long", strings.Repeat("a", 100000)}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := validator.ValidateToken(tt.token, false) - - if result.Valid { - if result.Claims != nil { - t.Logf("Token considered valid: %s", tt.name) - } - } else { - if result.Error == nil { - t.Error("Expected error for malicious payload") - } - } - }) - } -} - // ============================================================================= // CONSOLIDATED TOKEN TESTS // ============================================================================= @@ -2098,19 +1584,3 @@ func createTokenOfSize(baseToken string, targetSize int) string { return baseToken } - -func createTestJWTSimple(claims map[string]interface{}) string { - header := map[string]interface{}{ - "alg": "HS256", - "typ": "JWT", - } - - headerJSON, _ := json.Marshal(header) - claimsJSON, _ := json.Marshal(claims) - - headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) - claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) - signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature")) - - return headerB64 + "." + claimsB64 + "." + signature -} diff --git a/token_validation_rs.go b/token_validation_rs.go index c2aa6d0..f3d9ad8 100644 --- a/token_validation_rs.go +++ b/token_validation_rs.go @@ -149,7 +149,15 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo if rs.idToken != "" { return t.validateTokenExpiryRS(rs, rs.idToken) } - return true, false, false + // No ID token to corroborate an access token we cannot verify + // (Azure nonce-bearing Graph access tokens carry a proprietary, + // client-unverifiable signature). Do NOT authenticate on an + // unverified token: refresh if a refresh token is available, + // otherwise force re-authentication. + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true } } @@ -158,7 +166,12 @@ func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bo if rs.refreshToken != "" { return false, true, false } - return true, false, false + // Opaque access token, no ID token to corroborate it, and + // introspection was unavailable/disabled/errored (e.g. + // circuit-breaker open). There is nothing left to verify the token + // against, so fail closed and force re-authentication rather than + // trusting an unverified opaque token. + return false, false, true } if err := t.verifyToken(rs.idToken); err != nil { if strings.Contains(err.Error(), "token has expired") { diff --git a/token_validator.go b/token_validator.go deleted file mode 100644 index d70725b..0000000 --- a/token_validator.go +++ /dev/null @@ -1,263 +0,0 @@ -package traefikoidc - -import ( - "bytes" - "encoding/base64" - "fmt" - "strings" - "time" - - "github.com/lukaszraczylo/traefikoidc/internal/pool" -) - -// TokenValidator provides unified token validation functionality -type TokenValidator struct { - logger *Logger -} - -// NewTokenValidator creates a new token validator -func NewTokenValidator(logger *Logger) *TokenValidator { - if logger == nil { - logger = GetSingletonNoOpLogger() - } - return &TokenValidator{ - logger: logger, - } -} - -// TokenValidationResult contains the result of token validation -type TokenValidationResult struct { - Error error - Claims map[string]interface{} - Expiry *time.Time - IssuedAt *time.Time - TokenType string - Valid bool -} - -// ValidateToken performs comprehensive token validation -func (v *TokenValidator) ValidateToken(token string, requireJWT bool) TokenValidationResult { - result := TokenValidationResult{} - - // Basic validation - if token == "" { - result.Error = fmt.Errorf("token is empty") - return result - } - - // Check if it's a JWT or opaque token - dotCount := strings.Count(token, ".") - isJWT := dotCount == 2 - - if requireJWT && !isJWT { - result.Error = fmt.Errorf("token is not a valid JWT (found %d dots, expected 2)", dotCount) - return result - } - - if isJWT { - return v.validateJWT(token) - } else { - return v.validateOpaqueToken(token) - } -} - -// validateJWT validates a JWT token -func (v *TokenValidator) validateJWT(token string) TokenValidationResult { - result := TokenValidationResult{ - TokenType: "JWT", - } - - parts := strings.Split(token, ".") - if len(parts) != 3 { - result.Error = fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) - return result - } - - // Validate each part - for i, part := range parts { - if part == "" { - result.Error = fmt.Errorf("JWT part %d is empty", i) - return result - } - - // Check for valid base64url characters - if !v.isValidBase64URL(part) { - result.Error = fmt.Errorf("JWT part %d contains invalid base64url characters", i) - return result - } - } - - // Decode and parse claims - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - result.Error = fmt.Errorf("failed to decode JWT payload: %w", err) - return result - } - - var claims map[string]interface{} - pm := pool.Get() - decoder := pm.GetJSONDecoder(bytes.NewReader(payload)) - defer pm.PutJSONDecoder(decoder) - if err := decoder.Decode(&claims); err != nil { - result.Error = fmt.Errorf("failed to parse JWT claims: %w", err) - return result - } - - result.Claims = claims - - // Extract standard claims - if exp, ok := claims["exp"]; ok { - expTime := v.extractTime(exp) - if expTime != nil { - result.Expiry = expTime - // Check if expired - if time.Now().After(*expTime) { - result.Error = fmt.Errorf("token is expired (expired at %v)", expTime.Format(time.RFC3339)) - return result - } - } - } - - if iat, ok := claims["iat"]; ok { - iatTime := v.extractTime(iat) - if iatTime != nil { - result.IssuedAt = iatTime - // Check if issued in future - if iatTime.After(time.Now().Add(5 * time.Minute)) { - result.Error = fmt.Errorf("token issued in future (iat: %v)", iatTime.Format(time.RFC3339)) - return result - } - } - } - - // Check nbf (not before) - if nbf, ok := claims["nbf"]; ok { - nbfTime := v.extractTime(nbf) - if nbfTime != nil && time.Now().Before(*nbfTime) { - result.Error = fmt.Errorf("token not yet valid (nbf: %v)", nbfTime.Format(time.RFC3339)) - return result - } - } - - result.Valid = true - return result -} - -// validateOpaqueToken validates an opaque token -func (v *TokenValidator) validateOpaqueToken(token string) TokenValidationResult { - result := TokenValidationResult{ - TokenType: "Opaque", - } - - // Check minimum length - if len(token) < 20 { - result.Error = fmt.Errorf("opaque token too short (length: %d)", len(token)) - return result - } - - // Check for spaces - if strings.Contains(token, " ") { - result.Error = fmt.Errorf("opaque token contains spaces") - return result - } - - // Check for control characters - for i, char := range token { - if char < 32 || char == 127 { - result.Error = fmt.Errorf("opaque token contains control character at position %d", i) - return result - } - } - - // Check entropy - if len(token) >= 20 { - uniqueChars := make(map[rune]bool) - for _, char := range token { - uniqueChars[char] = true - } - if len(uniqueChars) < 8 { - result.Error = fmt.Errorf("opaque token has insufficient entropy (unique chars: %d)", len(uniqueChars)) - return result - } - } - - result.Valid = true - return result -} - -// isValidBase64URL checks if a string contains only valid base64url characters -func (v *TokenValidator) isValidBase64URL(s string) bool { - for _, char := range s { - if !((char >= 'A' && char <= 'Z') || - (char >= 'a' && char <= 'z') || - (char >= '0' && char <= '9') || - char == '-' || char == '_' || char == '=') { - return false - } - } - return true -} - -// extractTime extracts a time.Time from various claim formats -func (v *TokenValidator) extractTime(claim interface{}) *time.Time { - var timestamp int64 - - switch val := claim.(type) { - case float64: - timestamp = int64(val) - case int64: - timestamp = val - case int: - timestamp = int64(val) - default: - return nil - } - - t := time.Unix(timestamp, 0) - return &t -} - -// ValidateTokenSize checks if token size is within acceptable limits -func (v *TokenValidator) ValidateTokenSize(token string, maxSize int) error { - if len(token) > maxSize { - return fmt.Errorf("token exceeds maximum size (size: %d, max: %d)", len(token), maxSize) - } - return nil -} - -// ExtractClaims extracts claims from a JWT without full validation -func (v *TokenValidator) ExtractClaims(token string) (map[string]interface{}, error) { - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, fmt.Errorf("invalid JWT format") - } - - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to decode payload: %w", err) - } - - var claims map[string]interface{} - pm := pool.Get() - decoder := pm.GetJSONDecoder(bytes.NewReader(payload)) - defer pm.PutJSONDecoder(decoder) - if err := decoder.Decode(&claims); err != nil { - return nil, fmt.Errorf("failed to parse claims: %w", err) - } - - return claims, nil -} - -// CompareTokens safely compares two tokens for equality -func (v *TokenValidator) CompareTokens(token1, token2 string) bool { - if len(token1) != len(token2) { - return false - } - - // Use constant-time comparison to prevent timing attacks - var result byte - for i := 0; i < len(token1); i++ { - result |= token1[i] ^ token2[i] - } - return result == 0 -} diff --git a/universal_cache.go b/universal_cache.go index 6161966..de9bc67 100644 --- a/universal_cache.go +++ b/universal_cache.go @@ -957,17 +957,28 @@ func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl tim } now := time.Now() + // Replace an existing entry in place: update the item and move its single + // list node to the front. Without this, a repeat populate of the same key + // (the per-request Get->backend-hit path) would PushFront a duplicate node + // and overwrite c.items[key], orphaning the previous node. Orphans inflate + // currentMemory/currentSize and, once eviction deletes the key, leave a + // Back() node whose key is absent from c.items — so evictOldest() spins + // while holding c.mu.Lock(): the 100%-CPU write-lock convoy seen in pprof. + // setLocal dedups the same way; evictOldest also guards any dangling node. + if existing, exists := c.items[key]; exists { + c.currentMemory -= existing.Size + c.lruList.Remove(existing.element) - // Replace any existing entry in place. Without this, a repeat populate of - // the same key (the per-request Get->backend-hit path at line ~359) - // PushFronts a second list node and overwrites c.items[key], orphaning the - // previous node. Orphans inflate currentMemory/currentSize and — once the - // eviction loop deletes the key — leave Back() nodes whose key is absent - // from c.items, so evictOldest() no-ops while lruList.Len()>0 stays true: - // an infinite loop while holding c.mu.Lock(), i.e. the 100%-CPU holder and - // write-lock convoy. setLocal already dedups on this path; this mirrors it. - if existing, ok := c.items[key]; ok { - c.removeItem(key, existing) + existing.Value = value + existing.Size = size + existing.ExpiresAt = now.Add(ttl) + existing.LastAccessed = now + existing.AccessCount++ + + existing.element = c.lruList.PushFront(key) + c.currentMemory += size + + return nil } item := &CacheItem{ diff --git a/url_helpers.go b/url_helpers.go index 6b077c8..c0a6511 100644 --- a/url_helpers.go +++ b/url_helpers.go @@ -19,7 +19,7 @@ import ( // - true if the URL should be excluded from authentication, false otherwise. func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { for excludedURL := range t.excludedURLs { - if strings.HasPrefix(currentRequest, excludedURL) { + if pathExcluded(currentRequest, excludedURL) { t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL) return true } @@ -27,6 +27,31 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { return false } +// pathExcluded reports whether requestPath is covered by an excluded prefix at a +// natural boundary: an exact match, a sub-path ("/public" → "/public/x"), or a +// file extension ("/favicon" → "/favicon.ico"). It deliberately does NOT match +// an unrelated sibling such as "/publicsecret", so a configured exclusion can no +// longer be widened into an authentication bypass on a different resource. +func pathExcluded(requestPath, excluded string) bool { + excluded = strings.TrimRight(excluded, "/") + if excluded == "" { + // A "/" (root) exclusion only matches the root path, not everything. + return requestPath == "" || requestPath == "/" + } + if requestPath == excluded { + return true + } + if !strings.HasPrefix(requestPath, excluded) { + return false + } + switch requestPath[len(excluded)] { + case '/', '.': + return true + default: + return false + } +} + // buildAuthURL constructs the OIDC provider authorization URL. // It builds the URL with all necessary parameters including client_id, scopes, // PKCE parameters, and provider-specific parameters for Google and Azure. @@ -288,6 +313,63 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error { return nil } +// validateDiscoveredEndpoint validates an endpoint URL obtained from the +// provider's OIDC/OAuth2 discovery document before the plugin issues any +// outbound request to it. A discovery document is attacker-influenced if the +// provider is malicious or its TLS is broken, so an unvalidated endpoint is an +// SSRF vector (e.g. jwks_uri or introspection_endpoint pointed at the cloud +// metadata service 169.254.169.254 or an internal host). +// +// Empty endpoints are allowed (they are optional). Link-local (which covers the +// 169.254.0.0/16 metadata range), multicast and unspecified addresses are +// always rejected. Private addresses are rejected unless allowPrivateIPAddresses +// is set. Loopback is rejected unless allowLoopback is true — which the caller +// sets only when the operator-configured providerURL is itself loopback (local +// development, in-cluster sidecars, tests), so production deployments pointed at +// a real provider still block loopback SSRF. +func (t *TraefikOidc) validateDiscoveredEndpoint(urlStr string, allowLoopback bool) error { + if urlStr == "" { + return nil + } + u, err := url.Parse(urlStr) + if err != nil { + return fmt.Errorf("invalid URL format: %w", err) + } + if u.Scheme != "https" && u.Scheme != "http" { + return fmt.Errorf("disallowed URL scheme: %q", u.Scheme) + } + if u.Host == "" { + return fmt.Errorf("missing host in URL") + } + if ip := net.ParseIP(u.Hostname()); ip != nil { + switch { + case ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified(): + return fmt.Errorf("endpoint host is a blocked address: %s", ip) + case ip.IsLoopback() && !allowLoopback: + return fmt.Errorf("endpoint host is a loopback address: %s", ip) + case ip.IsPrivate() && !t.allowPrivateIPAddresses: + return fmt.Errorf("endpoint host is a private address: %s", ip) + } + } + if strings.Contains(u.Path, "..") { + return fmt.Errorf("path traversal detected in URL path") + } + return nil +} + +// sameHost reports whether two URLs share the same host:port (case-insensitive). +// Used to pin the credential-bearing introspection endpoint to the operator- +// configured provider so a poisoned discovery document cannot redirect the +// client secret to an attacker-controlled host. +func sameHost(a, b string) bool { + ua, erra := url.Parse(a) + ub, errb := url.Parse(b) + if erra != nil || errb != nil || ua.Host == "" || ub.Host == "" { + return false + } + return strings.EqualFold(ua.Host, ub.Host) +} + // validateHost validates a hostname or IP address for security. // It prevents access to localhost, private networks, and known metadata endpoints. // When allowPrivateIPAddresses is enabled, private IP checks are skipped. diff --git a/utilities.go b/utilities.go index 7887649..e2d34b5 100644 --- a/utilities.go +++ b/utilities.go @@ -135,7 +135,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool { return false } - domain := parts[1] + domain := strings.ToLower(parts[1]) _, domainAllowed := t.allowedUserDomains[domain] if domainAllowed { @@ -236,8 +236,13 @@ func (t *TraefikOidc) Close() error { // Get resource manager for cleanup rm := GetResourceManager() - // Stop singleton tasks related to this instance - _ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup + // singleton-token-cleanup is a process-global task shared by every plugin + // instance. Only stop it when the LAST instance is shutting down; + // otherwise one instance's teardown (e.g. a single config reload) would + // kill chunked-session/token cleanup for all surviving instances (rank 12). + if unregisterLiveInstance() <= 0 { + _ = rm.StopBackgroundTask("singleton-token-cleanup") // best effort, last instance only + } // Stop metadata refresh task using same hash-based name as startMetadataRefresh if t.providerURL != "" { hash := sha256.Sum256([]byte(t.providerURL))