From 5fcbd549558d8215f8ac89c1508dca6b75f72eb0 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sun, 30 Nov 2025 01:41:12 +0000 Subject: [PATCH] Add sharded cache and prevention of CPU spikes / locks (#96) * Add sharded cache and prevention of CPU spikes / locks * Add dynamic client registration with oidc provider * Fix race condition introduced during the sharded cache implementation. * Add page for traefikoidc. --- .github/workflows/pr-validation.yml | 23 +- README.md | 118 ++++ config/settings.go | 83 +++ docs/index.html | 831 ++++++++++++++++++++++ dynamic_client_registration.go | 550 +++++++++++++++ dynamic_client_registration_test.go | 1002 +++++++++++++++++++++++++++ http_client_pool.go | 52 +- jwt.go | 126 +++- main.go | 70 +- main_test.go | 22 +- refresh_coordinator.go | 100 ++- refresh_coordinator_test.go | 95 +++ session_chunk_manager.go | 146 ++-- settings.go | 85 +++ sharded_cache.go | 207 ++++++ sharded_cache_test.go | 413 +++++++++++ singleton_resources.go | 59 +- singleton_resources_test.go | 279 ++++++++ token_manager.go | 19 +- types.go | 6 + universal_cache.go | 13 +- universal_cache_singleton.go | 154 +++- 22 files changed, 4262 insertions(+), 191 deletions(-) create mode 100644 docs/index.html create mode 100644 dynamic_client_registration.go create mode 100644 dynamic_client_registration_test.go create mode 100644 sharded_cache.go create mode 100644 sharded_cache_test.go diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 4d480e2..bf656a0 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -211,11 +211,6 @@ jobs: echo "coverage=$COVERAGE" >> $GITHUB_OUTPUT echo "Total Coverage: $COVERAGE%" - # Get per-package coverage - echo "## Coverage by Package" >> coverage_report.md - echo "" >> coverage_report.md - go tool cover -func=coverage.out | grep -v "total:" | awk '{print "- " $1 ": " $3}' >> coverage_report.md || true - - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: @@ -230,21 +225,19 @@ jobs: uses: actions/github-script@v8 with: script: | - const fs = require('fs'); const coverage = '${{ steps.coverage.outputs.coverage }}'; - let coverageReport = ''; - - try { - coverageReport = fs.readFileSync('coverage_report.md', 'utf8'); - } catch (e) { - coverageReport = 'Coverage details not available'; - } - const threshold = 70; const coverageNum = parseFloat(coverage); const emoji = coverageNum >= threshold ? '✅' : '⚠️'; + const status = coverageNum >= threshold ? 'meets' : 'below'; - const body = `## ${emoji} Test Coverage Report\n\n**Total Coverage:** ${coverage}%\n**Threshold:** ${threshold}%\n\n${coverageReport}`; + const body = `## ${emoji} Test Coverage Report + + | Metric | Value | + |--------|-------| + | **Total Coverage** | ${coverage}% | + | **Threshold** | ${threshold}% | + | **Status** | ${emoji} Coverage ${status} threshold |`; // Find existing comment const { data: comments } = await github.rest.issues.listComments({ diff --git a/README.md b/README.md index 919ffb2..bdee348 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit - **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more - **Automatic provider detection**: Automatically detects and configures provider-specific settings +- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration - **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes - **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles - **Domain restrictions**: Limit access to specific email domains or individual users @@ -552,6 +553,123 @@ spec: **Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, set to `true` and consider implementing a shared cache backend (Redis/Memcached) if replay detection is required. +## Dynamic Client Registration (RFC 7591) + +The middleware supports **OIDC Dynamic Client Registration** (RFC 7591), allowing automatic client registration with OIDC providers without manual pre-registration. This is useful for: + +- **Multi-tenant deployments**: Automatically register clients per tenant +- **Development environments**: Quick setup without manual OAuth app creation +- **Self-service integrations**: Allow applications to self-register + +### How It Works + +1. When enabled, the middleware discovers the `registration_endpoint` from the provider's `.well-known/openid-configuration` +2. If no `clientID` is configured, it automatically registers a new client with the provider +3. The registered `client_id` and `client_secret` are cached and optionally persisted to a file +4. Subsequent requests use the registered credentials + +### Configuration + +```yaml +apiVersion: traefik.io/v1alpha1 +kind: Middleware +metadata: + name: oidc-dynamic-registration + namespace: traefik +spec: + plugin: + traefikoidc: + providerURL: https://your-oidc-provider.com + # clientID and clientSecret are NOT required when using DCR + sessionEncryptionKey: your-secure-encryption-key-min-32-chars + callbackURL: /oauth2/callback + + dynamicClientRegistration: + enabled: true + + # Optional: Initial access token for protected registration endpoints + initialAccessToken: "your-initial-access-token" + + # Optional: Override the registration endpoint (auto-discovered by default) + registrationEndpoint: "https://your-provider.com/register" + + # Optional: Persist credentials to file for reuse across restarts + persistCredentials: true + credentialsFile: "/tmp/oidc-client-credentials.json" + + # Client metadata for registration + clientMetadata: + redirect_uris: + - "https://your-app.com/oauth2/callback" + client_name: "My Application" + application_type: "web" + grant_types: + - "authorization_code" + - "refresh_token" + response_types: + - "code" + token_endpoint_auth_method: "client_secret_basic" + contacts: + - "admin@your-app.com" +``` + +### DCR Configuration Parameters + +| Parameter | Description | Required | Default | +|-----------|-------------|----------|---------| +| `enabled` | Enable dynamic client registration | Yes | `false` | +| `initialAccessToken` | Bearer token for protected registration endpoints | No | - | +| `registrationEndpoint` | Override auto-discovered registration endpoint | No | From discovery | +| `persistCredentials` | Save registered credentials to file | No | `false` | +| `credentialsFile` | Path to store/load credentials | No | `/tmp/oidc-client-credentials.json` | +| `clientMetadata.redirect_uris` | **REQUIRED** - Redirect URIs for OAuth flow | Yes | - | +| `clientMetadata.client_name` | Human-readable client name | No | - | +| `clientMetadata.application_type` | `web` or `native` | No | `web` | +| `clientMetadata.grant_types` | OAuth grant types | No | `["authorization_code", "refresh_token"]` | +| `clientMetadata.response_types` | OAuth response types | No | `["code"]` | +| `clientMetadata.token_endpoint_auth_method` | Authentication method | No | `client_secret_basic` | +| `clientMetadata.contacts` | Contact email addresses | No | - | +| `clientMetadata.logo_uri` | URL to client logo | No | - | +| `clientMetadata.client_uri` | URL to client homepage | No | - | +| `clientMetadata.policy_uri` | URL to privacy policy | No | - | +| `clientMetadata.tos_uri` | URL to terms of service | No | - | +| `clientMetadata.scope` | Space-separated scopes | No | - | + +### Provider Support + +DCR support varies by provider: + +| Provider | DCR Support | Notes | +|----------|-------------|-------| +| Keycloak | ✅ Full | Enable in realm settings | +| Auth0 | ✅ Full | Requires Management API token | +| Okta | ✅ Full | Enable Dynamic Client Registration | +| Azure AD | ⚠️ Limited | App Registration API instead | +| Google | ❌ No | Manual registration required | +| AWS Cognito | ❌ No | Manual registration required | + +### Security Considerations + +1. **HTTPS Required**: Registration endpoints must use HTTPS (except localhost for development) +2. **Initial Access Token**: Recommended for production to prevent unauthorized registrations +3. **Credential Persistence**: If enabled, ensure the credentials file has appropriate permissions (0600) +4. **Secret Expiration**: Monitor `client_secret_expires_at` and handle rotation if needed + +### Example: Keycloak with DCR + +```yaml +dynamicClientRegistration: + enabled: true + clientMetadata: + redirect_uris: + - "https://myapp.example.com/oauth2/callback" + client_name: "My App - Production" + application_type: "web" + grant_types: + - "authorization_code" + - "refresh_token" +``` + ## Usage Examples ### Basic Configuration diff --git a/config/settings.go b/config/settings.go index ae780e3..aa80724 100644 --- a/config/settings.go +++ b/config/settings.go @@ -69,6 +69,89 @@ type Config struct { HTTPClient *http.Client `json:"-"` CookieDomain string `json:"cookieDomain"` SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"` + + // Dynamic Client Registration (RFC 7591) configuration + DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"` +} + +// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591) +type DynamicClientRegistrationConfig struct { + // Enabled enables automatic client registration with the OIDC provider + Enabled bool `json:"enabled"` + + // InitialAccessToken is an optional bearer token for protected registration endpoints + // Some providers require this token to authorize new client registrations + InitialAccessToken string `json:"initialAccessToken,omitempty"` + + // RegistrationEndpoint overrides the endpoint discovered from provider metadata + // If empty, uses the registration_endpoint from .well-known/openid-configuration + RegistrationEndpoint string `json:"registrationEndpoint,omitempty"` + + // ClientMetadata contains the client metadata to register + ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"` + + // PersistCredentials determines whether to save registered credentials to a file + // This allows reusing the same client_id/client_secret across restarts + PersistCredentials bool `json:"persistCredentials"` + + // CredentialsFile is the path to store/load registered client credentials + // Defaults to "/tmp/oidc-client-credentials.json" if not specified + CredentialsFile string `json:"credentialsFile,omitempty"` +} + +// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591) +type ClientRegistrationMetadata struct { + // RedirectURIs is REQUIRED - array of redirect URIs for authorization + RedirectURIs []string `json:"redirect_uris"` + + // ResponseTypes specifies OAuth 2.0 response types (default: ["code"]) + ResponseTypes []string `json:"response_types,omitempty"` + + // GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"]) + GrantTypes []string `json:"grant_types,omitempty"` + + // ApplicationType is either "web" (default) or "native" + ApplicationType string `json:"application_type,omitempty"` + + // Contacts is an array of email addresses for responsible parties + Contacts []string `json:"contacts,omitempty"` + + // ClientName is a human-readable name for the client + ClientName string `json:"client_name,omitempty"` + + // LogoURI is a URL pointing to a logo for the client + LogoURI string `json:"logo_uri,omitempty"` + + // ClientURI is a URL of the home page of the client + ClientURI string `json:"client_uri,omitempty"` + + // PolicyURI is a URL pointing to the client's privacy policy + PolicyURI string `json:"policy_uri,omitempty"` + + // TOSURI is a URL pointing to the client's terms of service + TOSURI string `json:"tos_uri,omitempty"` + + // JWKSURI is a URL for the client's JSON Web Key Set + JWKSURI string `json:"jwks_uri,omitempty"` + + // SubjectType is "pairwise" or "public" (provider-specific) + SubjectType string `json:"subject_type,omitempty"` + + // TokenEndpointAuthMethod specifies how the client authenticates at token endpoint + // Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none" + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // DefaultMaxAge is the default maximum authentication age in seconds + DefaultMaxAge int `json:"default_max_age,omitempty"` + + // RequireAuthTime specifies whether auth_time claim is required in ID token + RequireAuthTime bool `json:"require_auth_time,omitempty"` + + // DefaultACRValues specifies default ACR values + DefaultACRValues []string `json:"default_acr_values,omitempty"` + + // Scope is a space-separated list of scopes (alternative to config.Scopes) + Scope string `json:"scope,omitempty"` } // HeaderConfig represents header template configuration diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..d22eafa --- /dev/null +++ b/docs/index.html @@ -0,0 +1,831 @@ + + + + + + Traefik OIDC - OpenID Connect Authentication Middleware + + + + + + + + + + + + + + + +
+
+
+
+
+ +
+
+
+
+ +
+
+

+ OpenID Connect for
Traefik +

+

+ Production-ready OIDC authentication middleware. Drop-in replacement for oauth2-proxy and forward-auth with support for 9+ identity providers. +

+ +
+ Version + License + Go Report + Coverage +
+
+
+
+ + +
+
+
+

Features

+

Enterprise-grade authentication for your Traefik deployments

+
+
+
+
+
+ +
+
+

Universal Provider Support

+

Works with Google, Azure AD, Auth0, Okta, Keycloak, Cognito, GitLab, and any OIDC-compliant provider

+
+
+
+
+
+
+ +
+
+

Auto-Detection

+

Automatically detects and configures provider-specific settings from OIDC discovery

+
+
+
+
+
+
+ +
+
+

Dynamic Registration

+

RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration

+
+
+
+
+
+
+ +
+
+

Automatic Scope Filtering

+

Intelligently filters OAuth scopes based on provider capabilities from discovery documents

+
+
+
+
+
+
+ +
+
+

Security Headers

+

Comprehensive security headers including CORS, CSP, HSTS, and customizable profiles

+
+
+
+
+
+
+ +
+
+

Domain & User Restrictions

+

Limit access to specific email domains, individual users, or role-based groups

+
+
+
+
+
+
+ +
+
+

Role-Based Access

+

Restrict access based on roles and groups from OIDC claims

+
+
+
+
+
+
+ +
+
+

Automatic Token Refresh

+

Secure session handling with proactive token refresh before expiry

+
+
+
+
+
+
+ +
+
+

Rate Limiting

+

Built-in protection against brute force attacks with configurable limits

+
+
+
+
+
+
+ +
+
+

Custom Headers

+

Template-based headers using OIDC claims and tokens for downstream services

+
+
+
+
+
+
+ +
+
+

PKCE Support

+

Proof Key for Code Exchange for enhanced security in authorization code flow

+
+
+
+
+
+
+ +
+
+

Memory Management

+

Bounded caches with LRU eviction, automatic cleanup, and zero goroutine leaks

+
+
+
+
+
+
+ + +
+
+
+

Supported Providers

+

Works with all major identity providers out of the box

+
+ + +
+
+
+ +
+

Google

+

Full OIDC

+
+
+
+ +
+

Azure AD

+

Full OIDC

+
+
+
+ A0 +
+

Auth0

+

Full OIDC

+
+
+
+ OK +
+

Okta

+

Full OIDC

+
+
+
+ KC +
+

Keycloak

+

Full OIDC

+
+
+
+ +
+

AWS Cognito

+

Full OIDC

+
+
+
+ +
+

GitLab

+

Full OIDC

+
+
+
+ +
+

GitHub

+

OAuth 2.0

+
+
+ + +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FeatureGoogleAzure ADAuth0OktaKeycloak
ID Tokens
Refresh Tokens
Auto-Configuration
Custom ClaimsLimited
Group/Role ClaimsLimited
Self-Hosted
+
+
+
+
+ + +
+
+
+

Installation

+

Get started in under 5 minutes

+
+
+
+

+ 1 + Enable the Plugin +

+

Add to your Traefik static configuration:

+
# traefik.yml
+experimental:
+  plugins:
+    traefikoidc:
+      moduleName: github.com/lukaszraczylo/traefikoidc
+      version: v0.7.10
+
+
+

+ 2 + Configure the Middleware +

+

Create your middleware configuration:

+
# dynamic/middleware.yml
+http:
+  middlewares:
+    oidc-auth:
+      plugin:
+        traefikoidc:
+          providerURL: "https://accounts.google.com"
+          clientID: "your-client-id"
+          clientSecret: "your-client-secret"
+          callbackURL: "/oauth2/callback"
+          sessionEncryptionKey: "your-32-byte-secret-key-here!!"
+          scopes:
+            - "openid"
+            - "profile"
+            - "email"
+
+
+

+ 3 + Apply to Your Routes +

+

Use the middleware on your services:

+
# dynamic/routers.yml
+http:
+  routers:
+    my-secure-app:
+      rule: "Host(`app.example.com`)"
+      service: my-service
+      middlewares:
+        - oidc-auth
+      tls:
+        certResolver: letsencrypt
+
+
+
+
+ + +
+
+
+

Configuration

+

Flexible options for any deployment scenario

+
+
+
+

Required Parameters

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ParameterDescription
providerURLBase URL of your OIDC provider
clientIDOAuth 2.0 client identifier
clientSecretOAuth 2.0 client secret
sessionEncryptionKey32+ byte key for session encryption
callbackURLOAuth callback path (e.g., /oauth2/callback)
+
+
+
+

Popular Optional Parameters

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ParameterDefaultDescription
forceHTTPSfalseRequired for TLS termination at load balancer
allowedUserDomainsnoneRestrict to specific email domains
allowedRolesAndGroupsnoneRestrict to users with specific roles
excludedURLsnonePaths that bypass authentication
enablePKCEfalseEnable PKCE for enhanced security
rateLimit100Maximum requests per second
+
+
+
+

Example: Google Workspace with Domain Restriction

+
http:
+  middlewares:
+    google-oidc:
+      plugin:
+        traefikoidc:
+          providerURL: "https://accounts.google.com"
+          clientID: "1234567890.apps.googleusercontent.com"
+          clientSecret: "your-client-secret"
+          callbackURL: "/oauth2/callback"
+          sessionEncryptionKey: "your-32-byte-encryption-key!!"
+          allowedUserDomains:
+            - "yourcompany.com"
+            - "subsidiary.com"
+          excludedURLs:
+            - "/health"
+            - "/metrics"
+            - "/api/public"
+          forceHTTPS: true
+          logLevel: "info"
+
+
+
+
+ + +
+
+
+

Security First

+

Built with enterprise security requirements in mind

+
+
+
+
+

+ + Token Security +

+
    +
  • • JWT signature verification with JWK rotation
  • +
  • • Replay attack detection via JTI claims
  • +
  • • Strict audience and issuer validation
  • +
  • • Automatic token refresh before expiry
  • +
  • • Token revocation on logout
  • +
+
+
+

+ + Session Security +

+
    +
  • • AES-256-GCM encrypted session cookies
  • +
  • • CSRF protection with state parameter
  • +
  • • Secure, HttpOnly, SameSite cookies
  • +
  • • Configurable session timeouts
  • +
  • • Bounded session cache with LRU eviction
  • +
+
+
+

+ + Security Headers +

+
    +
  • • Content Security Policy (CSP)
  • +
  • • HTTP Strict Transport Security (HSTS)
  • +
  • • X-Frame-Options, X-Content-Type-Options
  • +
  • • CORS configuration
  • +
  • • Customizable header profiles
  • +
+
+
+

+ + Rate Limiting +

+
    +
  • • Configurable request rate limits
  • +
  • • Protection against brute force attacks
  • +
  • • Per-client rate limiting
  • +
  • • Graceful handling of limit exceeded
  • +
  • • Customizable response codes
  • +
+
+
+
+
+
+ + +
+
+
+

Why Choose Traefik OIDC?

+

A better alternative to oauth2-proxy and forward-auth

+
+
+
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FeatureTraefik OIDCoauth2-proxyforward-auth
Native Plugin
No Extra Service
Auto Provider Detection
Dynamic Client Registration
Automatic Scope Filtering
Built-in Security Headers
Template Headers
Memory Efficient LRU cachesVariesVaries
+
+
+
+
+
+ + +
+
+

Ready to Secure Your Applications?

+

+ Get started with Traefik OIDC in minutes. Full documentation and examples available on GitHub. +

+ +
+
+ + + + + + + diff --git a/dynamic_client_registration.go b/dynamic_client_registration.go new file mode 100644 index 0000000..e7ba0b0 --- /dev/null +++ b/dynamic_client_registration.go @@ -0,0 +1,550 @@ +// Package traefikoidc provides OIDC authentication middleware for Traefik +package traefikoidc + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "time" +) + +// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591) +type ClientRegistrationResponse struct { + // Required fields + ClientID string `json:"client_id"` + + // Conditional - only for confidential clients + ClientSecret string `json:"client_secret,omitempty"` + + // Optional - for managing registration + RegistrationAccessToken string `json:"registration_access_token,omitempty"` + RegistrationClientURI string `json:"registration_client_uri,omitempty"` + + // Expiration + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + + // Echo back of registered metadata + RedirectURIs []string `json:"redirect_uris,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ApplicationType string `json:"application_type,omitempty"` + Contacts []string `json:"contacts,omitempty"` + ClientName string `json:"client_name,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` + ClientURI string `json:"client_uri,omitempty"` + PolicyURI string `json:"policy_uri,omitempty"` + TOSURI string `json:"tos_uri,omitempty"` + JWKSURI string `json:"jwks_uri,omitempty"` + SubjectType string `json:"subject_type,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// ClientRegistrationError represents an error response from client registration (RFC 7591) +type ClientRegistrationError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591) +type DynamicClientRegistrar struct { + httpClient *http.Client + logger *Logger + config *DynamicClientRegistrationConfig + providerURL string + + // Cached registration response + mu sync.RWMutex + registrationResponse *ClientRegistrationResponse +} + +// NewDynamicClientRegistrar creates a new dynamic client registrar +func NewDynamicClientRegistrar( + httpClient *http.Client, + logger *Logger, + dcrConfig *DynamicClientRegistrationConfig, + providerURL string, +) *DynamicClientRegistrar { + if logger == nil { + logger = GetSingletonNoOpLogger() + } + + return &DynamicClientRegistrar{ + httpClient: httpClient, + logger: logger, + config: dcrConfig, + providerURL: providerURL, + } +} + +// RegisterClient performs dynamic client registration with the OIDC provider +// It first attempts to load existing credentials from a file if persistence is enabled, +// then registers a new client if no valid credentials exist. +func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) { + if r.config == nil || !r.config.Enabled { + return nil, fmt.Errorf("dynamic client registration is not enabled") + } + + // Try to load existing credentials if persistence is enabled + if r.config.PersistCredentials { + if resp, err := r.loadCredentials(); err == nil && resp != nil { + // Check if credentials are still valid (not expired) + if r.areCredentialsValid(resp) { + r.logger.Info("Loaded existing client credentials from file") + r.mu.Lock() + r.registrationResponse = resp + r.mu.Unlock() + return resp, nil + } + r.logger.Info("Existing credentials expired or invalid, registering new client") + } + } + + // Determine registration endpoint + endpoint := registrationEndpoint + if r.config.RegistrationEndpoint != "" { + endpoint = r.config.RegistrationEndpoint + } + + if endpoint == "" { + return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured") + } + + // Validate the endpoint URL + if !strings.HasPrefix(endpoint, "https://") { + // Allow http only for localhost/development + if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") { + return nil, fmt.Errorf("registration endpoint must use HTTPS for security") + } + r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint) + } + + // Build registration request + reqBody, err := r.buildRegistrationRequest() + if err != nil { + return nil, fmt.Errorf("failed to build registration request: %w", err) + } + + r.logger.Debugf("Registering client at endpoint: %s", endpoint) + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // Add Initial Access Token if provided + if r.config.InitialAccessToken != "" { + req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken) + } + + // Execute request + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request failed: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit + if err != nil { + return nil, fmt.Errorf("failed to read registration response: %w", err) + } + + // Handle error responses + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + var regError ClientRegistrationError + if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { + return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription) + } + return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse successful response + var regResp ClientRegistrationResponse + if err := json.Unmarshal(body, ®Resp); err != nil { + return nil, fmt.Errorf("failed to parse registration response: %w", err) + } + + // Validate response + if regResp.ClientID == "" { + return nil, fmt.Errorf("registration response missing client_id") + } + + r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID) + + // Cache the response + r.mu.Lock() + r.registrationResponse = ®Resp + r.mu.Unlock() + + // Persist credentials if enabled + if r.config.PersistCredentials { + if err := r.saveCredentials(®Resp); err != nil { + r.logger.Errorf("Failed to persist client credentials: %v", err) + // Don't fail registration if persistence fails + } + } + + return ®Resp, nil +} + +// buildRegistrationRequest creates the JSON request body for client registration +func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) { + metadata := r.config.ClientMetadata + if metadata == nil { + metadata = &ClientRegistrationMetadata{} + } + + // Build request object + reqData := make(map[string]interface{}) + + // Required: redirect_uris + if len(metadata.RedirectURIs) > 0 { + reqData["redirect_uris"] = metadata.RedirectURIs + } else { + return nil, fmt.Errorf("redirect_uris is required for client registration") + } + + // Optional fields - only include if set + if len(metadata.ResponseTypes) > 0 { + reqData["response_types"] = metadata.ResponseTypes + } else { + // Default to authorization code flow + reqData["response_types"] = []string{"code"} + } + + if len(metadata.GrantTypes) > 0 { + reqData["grant_types"] = metadata.GrantTypes + } else { + // Default grant types for authorization code flow + reqData["grant_types"] = []string{"authorization_code", "refresh_token"} + } + + if metadata.ApplicationType != "" { + reqData["application_type"] = metadata.ApplicationType + } + + if len(metadata.Contacts) > 0 { + reqData["contacts"] = metadata.Contacts + } + + if metadata.ClientName != "" { + reqData["client_name"] = metadata.ClientName + } + + if metadata.LogoURI != "" { + reqData["logo_uri"] = metadata.LogoURI + } + + if metadata.ClientURI != "" { + reqData["client_uri"] = metadata.ClientURI + } + + if metadata.PolicyURI != "" { + reqData["policy_uri"] = metadata.PolicyURI + } + + if metadata.TOSURI != "" { + reqData["tos_uri"] = metadata.TOSURI + } + + if metadata.JWKSURI != "" { + reqData["jwks_uri"] = metadata.JWKSURI + } + + if metadata.SubjectType != "" { + reqData["subject_type"] = metadata.SubjectType + } + + if metadata.TokenEndpointAuthMethod != "" { + reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod + } else { + // Default to client_secret_basic for confidential clients + reqData["token_endpoint_auth_method"] = "client_secret_basic" + } + + if metadata.DefaultMaxAge > 0 { + reqData["default_max_age"] = metadata.DefaultMaxAge + } + + if metadata.RequireAuthTime { + reqData["require_auth_time"] = metadata.RequireAuthTime + } + + if len(metadata.DefaultACRValues) > 0 { + reqData["default_acr_values"] = metadata.DefaultACRValues + } + + if metadata.Scope != "" { + reqData["scope"] = metadata.Scope + } + + return json.Marshal(reqData) +} + +// GetCachedResponse returns the cached registration response +func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse { + r.mu.RLock() + defer r.mu.RUnlock() + return r.registrationResponse +} + +// areCredentialsValid checks if the cached credentials are still valid +func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool { + if resp == nil || resp.ClientID == "" { + return false + } + + // Check if secret has expired + if resp.ClientSecretExpiresAt > 0 { + expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0) + // Add 5 minute buffer before expiration + if time.Now().Add(5 * time.Minute).After(expiresAt) { + return false + } + } + + return true +} + +// credentialsFilePath returns the path for storing credentials +func (r *DynamicClientRegistrar) credentialsFilePath() string { + if r.config.CredentialsFile != "" { + return r.config.CredentialsFile + } + return "/tmp/oidc-client-credentials.json" +} + +// saveCredentials persists client credentials to a file +func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error { + filePath := r.credentialsFilePath() + + data, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal credentials: %w", err) + } + + // Write with restrictive permissions (owner read/write only) + if err := os.WriteFile(filePath, data, 0600); err != nil { + return fmt.Errorf("failed to write credentials file: %w", err) + } + + r.logger.Debugf("Saved client credentials to %s", filePath) + return nil +} + +// loadCredentials loads client credentials from a file +func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) { + filePath := r.credentialsFilePath() + + data, err := os.ReadFile(filePath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil // No credentials file exists + } + return nil, fmt.Errorf("failed to read credentials file: %w", err) + } + + var resp ClientRegistrationResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("failed to parse credentials file: %w", err) + } + + return &resp, nil +} + +// UpdateClientRegistration updates an existing client registration using RFC 7592 +// This requires the registration_client_uri and registration_access_token from the original registration +func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) { + r.mu.RLock() + cachedResp := r.registrationResponse + r.mu.RUnlock() + + if cachedResp == nil { + return nil, fmt.Errorf("no existing registration to update") + } + + if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" { + return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token") + } + + // Build update request + reqBody, err := r.buildRegistrationRequest() + if err != nil { + return nil, fmt.Errorf("failed to build update request: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody)) + if err != nil { + return nil, fmt.Errorf("failed to create update request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken) + + // Execute request + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("update request failed: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read update response: %w", err) + } + + // Handle error responses + if resp.StatusCode != http.StatusOK { + var regError ClientRegistrationError + if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { + return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription) + } + return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse successful response + var regResp ClientRegistrationResponse + if err := json.Unmarshal(body, ®Resp); err != nil { + return nil, fmt.Errorf("failed to parse update response: %w", err) + } + + // Update cache + r.mu.Lock() + r.registrationResponse = ®Resp + r.mu.Unlock() + + // Persist updated credentials if enabled + if r.config.PersistCredentials { + if err := r.saveCredentials(®Resp); err != nil { + r.logger.Errorf("Failed to persist updated credentials: %v", err) + } + } + + r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID) + return ®Resp, nil +} + +// ReadClientRegistration reads the current client registration using RFC 7592 +func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) { + r.mu.RLock() + cachedResp := r.registrationResponse + r.mu.RUnlock() + + if cachedResp == nil { + return nil, fmt.Errorf("no existing registration to read") + } + + if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" { + return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token") + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil) + if err != nil { + return nil, fmt.Errorf("failed to create read request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken) + + // Execute request + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("read request failed: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Handle error responses + if resp.StatusCode != http.StatusOK { + var regError ClientRegistrationError + if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { + return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription) + } + return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse successful response + var regResp ClientRegistrationResponse + if err := json.Unmarshal(body, ®Resp); err != nil { + return nil, fmt.Errorf("failed to parse read response: %w", err) + } + + return ®Resp, nil +} + +// DeleteClientRegistration deletes the client registration using RFC 7592 +func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error { + r.mu.RLock() + cachedResp := r.registrationResponse + r.mu.RUnlock() + + if cachedResp == nil { + return fmt.Errorf("no existing registration to delete") + } + + if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" { + return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token") + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil) + if err != nil { + return fmt.Errorf("failed to create delete request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken) + + // Execute request + resp, err := r.httpClient.Do(req) + if err != nil { + return fmt.Errorf("delete request failed: %w", err) + } + defer resp.Body.Close() + + // Handle error responses (204 No Content is success) + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + var regError ClientRegistrationError + if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" { + return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription) + } + return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Clear cache + r.mu.Lock() + r.registrationResponse = nil + r.mu.Unlock() + + // Remove credentials file if persistence is enabled + if r.config.PersistCredentials { + filePath := r.credentialsFilePath() + if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { + r.logger.Errorf("Failed to remove credentials file: %v", err) + } + } + + r.logger.Info("Successfully deleted client registration") + return nil +} diff --git a/dynamic_client_registration_test.go b/dynamic_client_registration_test.go new file mode 100644 index 0000000..3ff9e99 --- /dev/null +++ b/dynamic_client_registration_test.go @@ -0,0 +1,1002 @@ +package traefikoidc + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" +) + +// TestDynamicClientRegistrarCreation tests creating a new DCR registrar +func TestDynamicClientRegistrarCreation(t *testing.T) { + tests := []struct { + name string + httpClient *http.Client + logger *Logger + dcrConfig *DynamicClientRegistrationConfig + providerURL string + }{ + { + name: "with all parameters", + httpClient: &http.Client{}, + logger: NewLogger("DEBUG"), + dcrConfig: &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + ClientName: "Test Client", + }, + }, + providerURL: "https://example.com", + }, + { + name: "with nil logger", + httpClient: &http.Client{}, + logger: nil, + dcrConfig: &DynamicClientRegistrationConfig{ + Enabled: true, + }, + providerURL: "https://example.com", + }, + { + name: "with nil config", + httpClient: &http.Client{}, + logger: NewLogger("DEBUG"), + dcrConfig: nil, + providerURL: "https://example.com", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + registrar := NewDynamicClientRegistrar(tc.httpClient, tc.logger, tc.dcrConfig, tc.providerURL) + + if registrar == nil { + t.Fatal("Expected non-nil registrar") + } + + if registrar.httpClient != tc.httpClient { + t.Error("HTTP client not set correctly") + } + + if registrar.providerURL != tc.providerURL { + t.Errorf("Provider URL mismatch: got %s, want %s", registrar.providerURL, tc.providerURL) + } + + if registrar.config != tc.dcrConfig { + t.Error("Config not set correctly") + } + + // Logger should never be nil (fallback to no-op logger) + if registrar.logger == nil { + t.Error("Logger should not be nil") + } + }) + } +} + +// TestRegisterClientSuccess tests successful client registration +func TestRegisterClientSuccess(t *testing.T) { + // Create mock server that returns successful registration response + expectedClientID := "test-client-id-12345" + expectedClientSecret := "test-client-secret-67890" + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request + if r.Method != http.MethodPost { + t.Errorf("Expected POST request, got %s", r.Method) + } + + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type: application/json, got %s", r.Header.Get("Content-Type")) + } + + // Parse request body + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Errorf("Failed to decode request body: %v", err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Verify redirect_uris is present + if _, ok := reqBody["redirect_uris"]; !ok { + t.Error("redirect_uris missing from request") + } + + // Return successful response + resp := ClientRegistrationResponse{ + ClientID: expectedClientID, + ClientSecret: expectedClientSecret, + ClientIDIssuedAt: time.Now().Unix(), + ClientSecretExpiresAt: 0, // Never expires + RedirectURIs: []string{"https://example.com/callback"}, + ResponseTypes: []string{"code"}, + GrantTypes: []string{"authorization_code", "refresh_token"}, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Errorf("Failed to encode response: %v", err) + } + })) + defer server.Close() + + // Create registrar + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + ClientName: "Test Client", + }, + } + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + // Perform registration + ctx := context.Background() + resp, err := registrar.RegisterClient(ctx, server.URL+"/register") + + if err != nil { + t.Fatalf("Registration failed: %v", err) + } + + if resp.ClientID != expectedClientID { + t.Errorf("ClientID mismatch: got %s, want %s", resp.ClientID, expectedClientID) + } + + if resp.ClientSecret != expectedClientSecret { + t.Errorf("ClientSecret mismatch: got %s, want %s", resp.ClientSecret, expectedClientSecret) + } + + // Verify response is cached + cached := registrar.GetCachedResponse() + if cached == nil { + t.Fatal("Response should be cached") + } + if cached.ClientID != expectedClientID { + t.Errorf("Cached ClientID mismatch: got %s, want %s", cached.ClientID, expectedClientID) + } +} + +// TestRegisterClientWithInitialAccessToken tests registration with an initial access token +func TestRegisterClientWithInitialAccessToken(t *testing.T) { + expectedToken := "initial-access-token-12345" + receivedToken := "" + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the authorization header + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + receivedToken = authHeader + } + + resp := ClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + InitialAccessToken: expectedToken, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + } + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + ctx := context.Background() + _, err := registrar.RegisterClient(ctx, server.URL+"/register") + + if err != nil { + t.Fatalf("Registration failed: %v", err) + } + + expectedAuthHeader := "Bearer " + expectedToken + if receivedToken != expectedAuthHeader { + t.Errorf("Authorization header mismatch: got %s, want %s", receivedToken, expectedAuthHeader) + } +} + +// TestRegisterClientError tests error handling during registration +func TestRegisterClientError(t *testing.T) { + tests := []struct { + name string + serverResponse func(w http.ResponseWriter, r *http.Request) + expectError bool + errorContains string + }{ + { + name: "invalid_redirect_uri error", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + resp := ClientRegistrationError{ + Error: "invalid_redirect_uri", + ErrorDescription: "The redirect_uri is not valid", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(resp) + }, + expectError: true, + errorContains: "invalid_redirect_uri", + }, + { + name: "invalid_client_metadata error", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + resp := ClientRegistrationError{ + Error: "invalid_client_metadata", + ErrorDescription: "Missing required field", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(resp) + }, + expectError: true, + errorContains: "invalid_client_metadata", + }, + { + name: "server error", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Server Error")) + }, + expectError: true, + errorContains: "500", + }, + { + name: "missing client_id in response", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + resp := map[string]string{ + "client_secret": "some-secret", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(resp) + }, + expectError: true, + errorContains: "missing client_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(tc.serverResponse)) + defer server.Close() + + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + } + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + ctx := context.Background() + _, err := registrar.RegisterClient(ctx, server.URL+"/register") + + if tc.expectError { + if err == nil { + t.Fatal("Expected error but got nil") + } + if tc.errorContains != "" && !stringContains(err.Error(), tc.errorContains) { + t.Errorf("Error should contain %q, got: %v", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + }) + } +} + +// TestRegisterClientDisabled tests that registration fails when not enabled +func TestRegisterClientDisabled(t *testing.T) { + tests := []struct { + name string + dcrConfig *DynamicClientRegistrationConfig + }{ + { + name: "nil config", + dcrConfig: nil, + }, + { + name: "enabled=false", + dcrConfig: &DynamicClientRegistrationConfig{ + Enabled: false, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + tc.dcrConfig, + "https://example.com", + ) + + ctx := context.Background() + _, err := registrar.RegisterClient(ctx, "https://example.com/register") + + if err == nil { + t.Fatal("Expected error when DCR is disabled") + } + + if !stringContains(err.Error(), "not enabled") { + t.Errorf("Error should mention 'not enabled', got: %v", err) + } + }) + } +} + +// TestRegisterClientMissingRedirectURIs tests that registration fails without redirect_uris +func TestRegisterClientMissingRedirectURIs(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + ClientName: "Test Client", + // Missing RedirectURIs + }, + } + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + ctx := context.Background() + _, err := registrar.RegisterClient(ctx, "https://example.com/register") + + if err == nil { + t.Fatal("Expected error when redirect_uris is missing") + } + + if !stringContains(err.Error(), "redirect_uris") { + t.Errorf("Error should mention 'redirect_uris', got: %v", err) + } +} + +// TestRegisterClientNoEndpoint tests that registration fails without an endpoint +func TestRegisterClientNoEndpoint(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + } + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + ctx := context.Background() + _, err := registrar.RegisterClient(ctx, "") // Empty endpoint + + if err == nil { + t.Fatal("Expected error when registration endpoint is missing") + } + + if !stringContains(err.Error(), "no registration endpoint") { + t.Errorf("Error should mention 'no registration endpoint', got: %v", err) + } +} + +// TestRegisterClientHTTPSRequired tests that HTTPS is required for non-localhost endpoints +func TestRegisterClientHTTPSRequired(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + } + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + ctx := context.Background() + _, err := registrar.RegisterClient(ctx, "http://example.com/register") // HTTP instead of HTTPS + + if err == nil { + t.Fatal("Expected error when using HTTP for non-localhost endpoint") + } + + if !stringContains(err.Error(), "HTTPS") { + t.Errorf("Error should mention 'HTTPS', got: %v", err) + } +} + +// TestRegisterClientCredentialsPersistence tests saving and loading credentials +func TestRegisterClientCredentialsPersistence(t *testing.T) { + // Create a temp file for credentials + tempDir := t.TempDir() + credentialsFile := filepath.Join(tempDir, "test-credentials.json") + + expectedClientID := "persisted-client-id" + expectedClientSecret := "persisted-client-secret" + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := ClientRegistrationResponse{ + ClientID: expectedClientID, + ClientSecret: expectedClientSecret, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + PersistCredentials: true, + CredentialsFile: credentialsFile, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + } + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + // First registration - should hit the server + ctx := context.Background() + resp, err := registrar.RegisterClient(ctx, server.URL+"/register") + if err != nil { + t.Fatalf("First registration failed: %v", err) + } + + if resp.ClientID != expectedClientID { + t.Errorf("ClientID mismatch: got %s, want %s", resp.ClientID, expectedClientID) + } + + // Verify credentials file was created + if _, err := os.Stat(credentialsFile); os.IsNotExist(err) { + t.Fatal("Credentials file was not created") + } + + // Create a new registrar to test loading + registrar2 := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + // Second registration - should load from file + resp2, err := registrar2.RegisterClient(ctx, server.URL+"/register") + if err != nil { + t.Fatalf("Second registration failed: %v", err) + } + + if resp2.ClientID != expectedClientID { + t.Errorf("Loaded ClientID mismatch: got %s, want %s", resp2.ClientID, expectedClientID) + } +} + +// TestCredentialsValidation tests the areCredentialsValid function +func TestCredentialsValidation(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} + registrar := NewDynamicClientRegistrar(&http.Client{}, NewLogger("DEBUG"), dcrConfig, "https://example.com") + + tests := []struct { + name string + response *ClientRegistrationResponse + expected bool + }{ + { + name: "nil response", + response: nil, + expected: false, + }, + { + name: "empty client_id", + response: &ClientRegistrationResponse{ + ClientID: "", + }, + expected: false, + }, + { + name: "valid non-expiring credentials", + response: &ClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecretExpiresAt: 0, // Never expires + }, + expected: true, + }, + { + name: "valid future-expiring credentials", + response: &ClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecretExpiresAt: time.Now().Add(1 * time.Hour).Unix(), + }, + expected: true, + }, + { + name: "expired credentials", + response: &ClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), + }, + expected: false, + }, + { + name: "about to expire credentials (within 5 min buffer)", + response: &ClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecretExpiresAt: time.Now().Add(2 * time.Minute).Unix(), + }, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := registrar.areCredentialsValid(tc.response) + if result != tc.expected { + t.Errorf("areCredentialsValid() = %v, want %v", result, tc.expected) + } + }) + } +} + +// TestBuildRegistrationRequest tests the request body construction +func TestBuildRegistrationRequest(t *testing.T) { + tests := []struct { + name string + metadata *ClientRegistrationMetadata + expectedFields map[string]interface{} + expectError bool + }{ + { + name: "minimal metadata", + metadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedFields: map[string]interface{}{ + "redirect_uris": []interface{}{"https://example.com/callback"}, + "response_types": []interface{}{"code"}, + "grant_types": []interface{}{"authorization_code", "refresh_token"}, + "token_endpoint_auth_method": "client_secret_basic", + }, + expectError: false, + }, + { + name: "full metadata", + metadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback", "https://example.com/callback2"}, + ResponseTypes: []string{"code", "token"}, + GrantTypes: []string{"authorization_code"}, + ApplicationType: "web", + Contacts: []string{"admin@example.com"}, + ClientName: "My Test Client", + LogoURI: "https://example.com/logo.png", + ClientURI: "https://example.com", + PolicyURI: "https://example.com/privacy", + TOSURI: "https://example.com/tos", + SubjectType: "public", + TokenEndpointAuthMethod: "client_secret_post", + DefaultMaxAge: 3600, + RequireAuthTime: true, + Scope: "openid profile email", + }, + expectedFields: map[string]interface{}{ + "redirect_uris": []interface{}{"https://example.com/callback", "https://example.com/callback2"}, + "response_types": []interface{}{"code", "token"}, + "grant_types": []interface{}{"authorization_code"}, + "application_type": "web", + "client_name": "My Test Client", + "token_endpoint_auth_method": "client_secret_post", + "default_max_age": float64(3600), + "require_auth_time": true, + "scope": "openid profile email", + }, + expectError: false, + }, + { + name: "missing redirect_uris", + metadata: &ClientRegistrationMetadata{ + ClientName: "Test Client", + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: tc.metadata, + } + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + reqBody, err := registrar.buildRegistrationRequest() + + if tc.expectError { + if err == nil { + t.Fatal("Expected error but got nil") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var reqData map[string]interface{} + if err := json.Unmarshal(reqBody, &reqData); err != nil { + t.Fatalf("Failed to unmarshal request body: %v", err) + } + + for field, expectedValue := range tc.expectedFields { + actualValue, ok := reqData[field] + if !ok { + t.Errorf("Missing expected field: %s", field) + continue + } + + // Compare JSON representations for slices + expectedJSON, _ := json.Marshal(expectedValue) + actualJSON, _ := json.Marshal(actualValue) + if string(expectedJSON) != string(actualJSON) { + t.Errorf("Field %s mismatch: got %v, want %v", field, actualValue, expectedValue) + } + } + }) + } +} + +// TestProviderMetadataRegistrationEndpoint tests that registration_endpoint is parsed from metadata +func TestProviderMetadataRegistrationEndpoint(t *testing.T) { + metadata := &ProviderMetadata{ + Issuer: "https://example.com", + AuthURL: "https://example.com/authorize", + TokenURL: "https://example.com/token", + JWKSURL: "https://example.com/.well-known/jwks.json", + RegistrationURL: "https://example.com/register", + } + + if metadata.RegistrationURL != "https://example.com/register" { + t.Errorf("RegistrationURL not set correctly: got %s", metadata.RegistrationURL) + } +} + +// TestDCRConfigDefaults tests default configuration values +func TestDCRConfigDefaults(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + } + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + // Test default credentials file path + path := registrar.credentialsFilePath() + if path != "/tmp/oidc-client-credentials.json" { + t.Errorf("Default credentials file path mismatch: got %s", path) + } + + // Test custom credentials file path + dcrConfig.CredentialsFile = "/custom/path/credentials.json" + path = registrar.credentialsFilePath() + if path != "/custom/path/credentials.json" { + t.Errorf("Custom credentials file path mismatch: got %s", path) + } +} + +// TestUpdateClientRegistration tests the RFC 7592 client update functionality +func TestUpdateClientRegistration(t *testing.T) { + updateCalled := false + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPut { + updateCalled = true + + // Verify authorization header + if r.Header.Get("Authorization") == "" { + t.Error("Missing Authorization header for update") + } + + resp := ClientRegistrationResponse{ + ClientID: "updated-client-id", + ClientSecret: "updated-client-secret", + RegistrationAccessToken: "new-access-token", + RegistrationClientURI: r.URL.String(), + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + ClientMetadata: &ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + } + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + // Set up cached response with management credentials + registrar.mu.Lock() + registrar.registrationResponse = &ClientRegistrationResponse{ + ClientID: "original-client-id", + ClientSecret: "original-client-secret", + RegistrationAccessToken: "access-token", + RegistrationClientURI: server.URL + "/register/client123", + } + registrar.mu.Unlock() + + // Perform update + ctx := context.Background() + resp, err := registrar.UpdateClientRegistration(ctx) + + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + if !updateCalled { + t.Error("Update endpoint was not called") + } + + if resp.ClientID != "updated-client-id" { + t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID) + } +} + +// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality +func TestDeleteClientRegistration(t *testing.T) { + deleteCalled := false + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodDelete { + deleteCalled = true + w.WriteHeader(http.StatusNoContent) + } + })) + defer server.Close() + + tempDir := t.TempDir() + credentialsFile := filepath.Join(tempDir, "credentials.json") + + // Create a credentials file to test deletion + os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600) + + dcrConfig := &DynamicClientRegistrationConfig{ + Enabled: true, + PersistCredentials: true, + CredentialsFile: credentialsFile, + } + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + // Set up cached response with management credentials + registrar.mu.Lock() + registrar.registrationResponse = &ClientRegistrationResponse{ + ClientID: "test-client-id", + RegistrationAccessToken: "access-token", + RegistrationClientURI: server.URL + "/register/client123", + } + registrar.mu.Unlock() + + // Perform delete + ctx := context.Background() + err := registrar.DeleteClientRegistration(ctx) + + if err != nil { + t.Fatalf("Delete failed: %v", err) + } + + if !deleteCalled { + t.Error("Delete endpoint was not called") + } + + // Verify cache is cleared + if registrar.GetCachedResponse() != nil { + t.Error("Cached response should be cleared after deletion") + } + + // Verify credentials file is deleted + if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) { + t.Error("Credentials file should be deleted") + } +} + +// TestReadClientRegistration tests the RFC 7592 client read functionality +func TestReadClientRegistration(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resp := ClientRegistrationResponse{ + ClientID: "read-client-id", + ClientSecret: "read-client-secret", + RedirectURIs: []string{"https://example.com/callback"}, + ResponseTypes: []string{"code"}, + GrantTypes: []string{"authorization_code"}, + ApplicationType: "web", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + } + })) + defer server.Close() + + dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} + + registrar := NewDynamicClientRegistrar( + server.Client(), + NewLogger("DEBUG"), + dcrConfig, + server.URL, + ) + + // Set up cached response with management credentials + registrar.mu.Lock() + registrar.registrationResponse = &ClientRegistrationResponse{ + ClientID: "original-client-id", + RegistrationAccessToken: "access-token", + RegistrationClientURI: server.URL + "/register/client123", + } + registrar.mu.Unlock() + + // Read registration + ctx := context.Background() + resp, err := registrar.ReadClientRegistration(ctx) + + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if resp.ClientID != "read-client-id" { + t.Errorf("Read ClientID mismatch: got %s", resp.ClientID) + } +} + +// TestOperationsWithoutCachedResponse tests error handling when no cached response exists +func TestOperationsWithoutCachedResponse(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + ctx := context.Background() + + // Test Update without cached response + _, err := registrar.UpdateClientRegistration(ctx) + if err == nil || !stringContains(err.Error(), "no existing registration") { + t.Errorf("Update should fail without cached response: %v", err) + } + + // Test Read without cached response + _, err = registrar.ReadClientRegistration(ctx) + if err == nil || !stringContains(err.Error(), "no existing registration") { + t.Errorf("Read should fail without cached response: %v", err) + } + + // Test Delete without cached response + err = registrar.DeleteClientRegistration(ctx) + if err == nil || !stringContains(err.Error(), "no existing registration") { + t.Errorf("Delete should fail without cached response: %v", err) + } +} + +// TestOperationsWithoutManagementCredentials tests error handling without management URIs +func TestOperationsWithoutManagementCredentials(t *testing.T) { + dcrConfig := &DynamicClientRegistrationConfig{Enabled: true} + + registrar := NewDynamicClientRegistrar( + &http.Client{}, + NewLogger("DEBUG"), + dcrConfig, + "https://example.com", + ) + + // Set up cached response WITHOUT management credentials + registrar.mu.Lock() + registrar.registrationResponse = &ClientRegistrationResponse{ + ClientID: "test-client-id", + // Missing RegistrationAccessToken and RegistrationClientURI + } + registrar.mu.Unlock() + + ctx := context.Background() + + // Test Update without management credentials + _, err := registrar.UpdateClientRegistration(ctx) + if err == nil || !stringContains(err.Error(), "registration management not supported") { + t.Errorf("Update should fail without management credentials: %v", err) + } + + // Test Read without management credentials + _, err = registrar.ReadClientRegistration(ctx) + if err == nil || !stringContains(err.Error(), "registration management not supported") { + t.Errorf("Read should fail without management credentials: %v", err) + } + + // Test Delete without management credentials + err = registrar.DeleteClientRegistration(ctx) + if err == nil || !stringContains(err.Error(), "registration management not supported") { + t.Errorf("Delete should fail without management credentials: %v", err) + } +} + +// stringContains is a helper function to check if a string contains a substring +func stringContains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr)) +} + +func stringContainsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/http_client_pool.go b/http_client_pool.go index 3ad1c83..8bdebc6 100644 --- a/http_client_pool.go +++ b/http_client_pool.go @@ -146,6 +146,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) { } // cleanupIdleTransports periodically cleans up unused transports +// Uses two-phase cleanup to minimize lock contention: +// 1. Find candidates while holding read lock +// 2. Remove and close transports with minimal lock duration func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() @@ -155,17 +158,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - p.mu.Lock() - now := time.Now() - for transportKey, shared := range p.transports { - // Clean up transports not used for 2 minutes with no references - if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute { - shared.transport.CloseIdleConnections() - delete(p.transports, transportKey) - // SECURITY FIX: Decrement client count when removing transport - atomic.AddInt32(&p.clientCount, -1) - } + p.performCleanup() + } + } +} + +// performCleanup does the actual cleanup with optimized locking +func (p *SharedTransportPool) performCleanup() { + now := time.Now() + + // Phase 1: Find candidates while holding read lock (fast) + p.mu.RLock() + candidates := make([]string, 0) + for transportKey, shared := range p.transports { + // Clean up transports not used for 2 minutes with no references + if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute { + candidates = append(candidates, transportKey) + } + } + p.mu.RUnlock() + + if len(candidates) == 0 { + return + } + + // Phase 2: Remove and close each candidate individually + // This minimizes lock contention and allows concurrent access + for _, key := range candidates { + p.mu.Lock() + shared, exists := p.transports[key] + if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute { + // Remove from map first (releases memory) + delete(p.transports, key) + atomic.AddInt32(&p.clientCount, -1) + p.mu.Unlock() + + // Close idle connections outside the lock (can be slow) + if shared.transport != nil { + shared.transport.CloseIdleConnections() } + } else { p.mu.Unlock() } } diff --git a/jwt.go b/jwt.go index 16cf121..dee9fbe 100644 --- a/jwt.go +++ b/jwt.go @@ -18,13 +18,16 @@ import ( "github.com/lukaszraczylo/traefikoidc/internal/pool" ) -// Replay attack protection cache and synchronization primitives. +// Replay attack protection cache using sharded design for reduced lock contention. // This cache tracks JWT IDs (jti claims) to prevent token reuse attacks. +// Under high load (500+ req/sec), the sharded design reduces contention significantly. var ( - // replayCacheMu protects access to the replay cache instance + // replayCacheMu protects access to the replay cache instance (only used for initialization) replayCacheMu sync.RWMutex - // replayCache stores JWT IDs with expiration to prevent replay attacks + // replayCache stores JWT IDs with expiration to prevent replay attacks (legacy interface) replayCache CacheInterface + // shardedReplayCache is the new high-performance sharded cache for replay detection + shardedReplayCache *ShardedCache // replayCacheOnce ensures the replay cache is initialized only once replayCacheOnce sync.Once // replayCacheCleanupWG waits for cleanup goroutine to finish @@ -36,10 +39,20 @@ var ( ) // initReplayCache initializes the JWT replay protection cache with bounded size. +// Uses a sharded cache design with 64 shards for reduced lock contention under high load. // The cache is bounded to 10,000 entries to prevent unbounded memory growth. // This function uses sync.Once to ensure thread-safe single initialization. func initReplayCache() { replayCacheOnce.Do(func() { + // Hold mutex during initialization to synchronize with cleanup goroutine + replayCacheMu.Lock() + defer replayCacheMu.Unlock() + + // Create sharded cache with 64 shards for reduced contention + // Under 500 req/sec, this reduces lock contention by ~64x compared to single mutex + shardedReplayCache = NewShardedCache(64, 10000) + + // Also initialize legacy cache for backward compatibility replayCache = NewCache() replayCache.SetMaxSize(10000) }) @@ -65,18 +78,32 @@ func cleanupReplayCache() { replayCacheMu.Lock() defer replayCacheMu.Unlock() + // Clear sharded cache + if shardedReplayCache != nil { + shardedReplayCache.Clear() + shardedReplayCache = nil + } + + // Clear legacy cache if replayCache != nil { replayCache.Close() replayCache = nil - replayCacheOnce = sync.Once{} } + + replayCacheOnce = sync.Once{} } // getReplayCacheStats returns statistics about the replay cache state. // Returns: -// - size: Current number of entries in the cache (currently always 0 due to interface limitations) +// - size: Current number of entries in the cache // - maxSize: Maximum allowed entries (10,000) func getReplayCacheStats() (size int, maxSize int) { + // Use sharded cache if available (no mutex needed due to internal sharding) + if shardedReplayCache != nil { + return shardedReplayCache.Size(), 10000 + } + + // Fall back to legacy cache replayCacheMu.RLock() defer replayCacheMu.RUnlock() @@ -98,16 +125,31 @@ func startReplayCacheCleanup(ctx context.Context, logger *Logger) { // Define the cleanup task function cleanupFunc := func() { + // Use mutex to safely access cache pointers - this prevents race with initReplayCache + replayCacheMu.RLock() + shardedCache := shardedReplayCache + legacyCache := replayCache + replayCacheMu.RUnlock() + + // Only proceed if caches have been initialized + if shardedCache == nil && legacyCache == nil { + return + } + size, maxSize := getReplayCacheStats() if logger != nil { logger.Debugf("Replay cache stats: size=%d, maxSize=%d", size, maxSize) } - replayCacheMu.RLock() - if replayCache != nil { - replayCache.Cleanup() + // Clean up sharded cache + if shardedCache != nil { + shardedCache.Cleanup() + } + + // Also clean up legacy cache for backward compatibility + if legacyCache != nil { + legacyCache.Cleanup() } - replayCacheMu.RUnlock() } // Create or get singleton cleanup task @@ -323,29 +365,51 @@ func (j *JWT) Verify(issuerURL, expectedAudience string, skipReplayCheck ...bool if jtiOk && !shouldSkipReplay && jtiValue != "" { initReplayCache() - replayCacheMu.RLock() - _, exists := replayCache.Get(jtiValue) - replayCacheMu.RUnlock() - - if exists { - return fmt.Errorf("token replay detected (jti: %s)", jtiValue) - } - - expFloat, ok := claims["exp"].(float64) - var expTime time.Time - if ok { - expTime = time.Unix(int64(expFloat), 0) - } else { - expTime = time.Now().Add(10 * time.Minute) - } - - duration := time.Until(expTime) - if duration > 0 { - replayCacheMu.Lock() - if replayCache != nil { - replayCache.Set(jtiValue, true, duration) + // Use sharded cache for replay detection - no global mutex needed + // This reduces lock contention by ~64x under high load + if shardedReplayCache != nil { + if shardedReplayCache.Exists(jtiValue) { + return fmt.Errorf("token replay detected (jti: %s)", jtiValue) + } + + expFloat, ok := claims["exp"].(float64) + var expTime time.Time + if ok { + expTime = time.Unix(int64(expFloat), 0) + } else { + expTime = time.Now().Add(10 * time.Minute) + } + + duration := time.Until(expTime) + if duration > 0 { + shardedReplayCache.Set(jtiValue, true, duration) + } + } else { + // Fall back to legacy cache with mutex (should rarely happen) + replayCacheMu.RLock() + _, exists := replayCache.Get(jtiValue) + replayCacheMu.RUnlock() + + if exists { + return fmt.Errorf("token replay detected (jti: %s)", jtiValue) + } + + expFloat, ok := claims["exp"].(float64) + var expTime time.Time + if ok { + expTime = time.Unix(int64(expFloat), 0) + } else { + expTime = time.Now().Add(10 * time.Minute) + } + + duration := time.Until(expTime) + if duration > 0 { + replayCacheMu.Lock() + if replayCache != nil { + replayCache.Set(jtiValue, true, duration) + } + replayCacheMu.Unlock() } - replayCacheMu.Unlock() } } diff --git a/main.go b/main.go index 9637952..884df79 100644 --- a/main.go +++ b/main.go @@ -205,6 +205,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name suppressDiagnosticLogs: isTestMode(), securityHeadersApplier: config.GetSecurityHeadersApplier(), scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering + dcrConfig: config.DynamicClientRegistration, } // Log audience configuration @@ -361,12 +362,13 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { // updateMetadataEndpoints updates internal endpoint URLs with discovered metadata. // It sets the authorization URL, token URL, JWKS URL, issuer URL, revocation URL, -// end session URL, and introspection URL based on the provider's metadata. +// end session URL, introspection URL, and registration URL based on the provider's metadata. +// If Dynamic Client Registration is enabled and no ClientID is configured, it will +// automatically register the client with the provider. // Parameters: // - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints. func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { t.metadataMu.Lock() - defer t.metadataMu.Unlock() t.jwksURL = metadata.JWKSURL t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery @@ -376,6 +378,9 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { t.revocationURL = metadata.RevokeURL t.endSessionURL = metadata.EndSessionURL t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662) + t.registrationURL = metadata.RegistrationURL // OIDC Dynamic Client Registration endpoint (RFC 7591) + + t.metadataMu.Unlock() // Log introspection endpoint availability for opaque token support if t.introspectionURL != "" { @@ -386,6 +391,67 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { } else if t.allowOpaqueTokens || t.requireTokenIntrospection { t.logger.Infof("⚠️ Opaque tokens enabled but no introspection endpoint available from provider") } + + // Log registration endpoint availability + if t.registrationURL != "" { + t.logger.Debugf("Dynamic client registration endpoint discovered: %s", t.registrationURL) + } + + // Perform Dynamic Client Registration if enabled and ClientID is not set + if t.dcrConfig != nil && t.dcrConfig.Enabled && t.clientID == "" { + t.performDynamicClientRegistration() + } +} + +// performDynamicClientRegistration performs automatic client registration with the OIDC provider +func (t *TraefikOidc) performDynamicClientRegistration() { + t.logger.Info("Dynamic Client Registration enabled - registering client with provider") + + // Initialize the DCR registrar if not already done + if t.dynamicClientRegistrar == nil { + t.dynamicClientRegistrar = NewDynamicClientRegistrar( + t.httpClient, + t.logger, + t.dcrConfig, + t.providerURL, + ) + } + + // Get registration endpoint (from metadata or config override) + registrationEndpoint := t.registrationURL + if t.dcrConfig.RegistrationEndpoint != "" { + registrationEndpoint = t.dcrConfig.RegistrationEndpoint + } + + // Perform registration + ctx, cancel := context.WithTimeout(t.ctx, 30*time.Second) + defer cancel() + + resp, err := t.dynamicClientRegistrar.RegisterClient(ctx, registrationEndpoint) + if err != nil { + t.logger.Errorf("Dynamic Client Registration failed: %v", err) + return + } + + // Update client credentials from registration response + t.metadataMu.Lock() + t.clientID = resp.ClientID + t.clientSecret = resp.ClientSecret + if t.audience == "" { + t.audience = resp.ClientID // Default audience to client ID + } + t.metadataMu.Unlock() + + t.logger.Infof("Dynamic Client Registration successful - client_id: %s", resp.ClientID) + + // Log additional registration details + if resp.ClientSecretExpiresAt > 0 { + expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0) + t.logger.Infof("Client secret expires at: %s", expiresAt.Format(time.RFC3339)) + } + if resp.RegistrationClientURI != "" { + t.logger.Debugf("Registration management URI: %s", resp.RegistrationClientURI) + } } // startMetadataRefresh starts a background goroutine that periodically refreshes provider metadata. diff --git a/main_test.go b/main_test.go index 931a1e8..30091f4 100644 --- a/main_test.go +++ b/main_test.go @@ -3214,10 +3214,8 @@ func TestAuthenticationFlowReplayDetection(t *testing.T) { t.Fatalf("Initial authentication should succeed: %v", err) } - // Verify JTI is in cache - replayCacheMu.Lock() - _, exists := replayCache.Get(jti) - replayCacheMu.Unlock() + // Verify JTI is in cache (use shardedReplayCache which is the actual cache used) + exists := shardedReplayCache.Exists(jti) if !exists { t.Error("JTI should be added to replay cache during initial authentication") } @@ -3398,14 +3396,12 @@ func TestConcurrentTokenValidation(t *testing.T) { t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors) } - // Verify all JTIs are in cache - replayCacheMu.Lock() + // Verify all JTIs are in cache (use shardedReplayCache which is the actual cache used) for i, jti := range jtis { - if _, exists := replayCache.Get(jti); !exists { + if !shardedReplayCache.Exists(jti) { t.Errorf("JTI %d (%s) should be in replay cache", i, jti) } } - replayCacheMu.Unlock() } // TestJTIBlacklistBehavior tests the JTI blacklist cache management @@ -3458,9 +3454,8 @@ func TestJTIBlacklistBehavior(t *testing.T) { { name: "JTI exists in blacklist after verification", action: func() error { - replayCacheMu.RLock() - defer replayCacheMu.RUnlock() - if _, exists := replayCache.Get(jti); !exists { + // Use shardedReplayCache which is the actual cache used + if !shardedReplayCache.Exists(jti) { return fmt.Errorf("JTI not found in blacklist cache") } return nil @@ -3567,9 +3562,8 @@ func TestSessionBasedTokenRevalidation(t *testing.T) { } // Check replay cache - replayCacheMu.Lock() - _, inReplayCache := replayCache.Get(jti) - replayCacheMu.Unlock() + // Use shardedReplayCache which is the actual cache used + inReplayCache := shardedReplayCache.Exists(jti) if !inReplayCache { t.Error("JTI should be in replay cache") } diff --git a/refresh_coordinator.go b/refresh_coordinator.go index 5c31464..bb8a238 100644 --- a/refresh_coordinator.go +++ b/refresh_coordinator.go @@ -40,6 +40,13 @@ type RefreshCoordinator struct { // Cleanup goroutine control stopChan chan struct{} wg sync.WaitGroup + + // delayedCleanupQueue stores items to be cleaned up after delay + // Uses a timer-based approach instead of spawning goroutines per cleanup + delayedCleanupQueue chan delayedCleanupItem + // cleanupTimerPool reuses timers to avoid goroutine-per-cleanup + cleanupTimerMu sync.Mutex + cleanupTimers map[string]*time.Timer } // RefreshCoordinatorConfig configures the refresh coordinator behavior @@ -131,6 +138,12 @@ type RefreshMetrics struct { currentInFlightRefreshes int32 } +// delayedCleanupItem represents an item scheduled for delayed cleanup +type delayedCleanupItem struct { + tokenHash string + cleanupAt time.Time +} + // RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations type RefreshCircuitBreaker struct { state int32 // 0=closed, 1=open, 2=half-open @@ -161,6 +174,8 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref metrics: &RefreshMetrics{}, logger: logger, stopChan: make(chan struct{}), + delayedCleanupQueue: make(chan delayedCleanupItem, 1000), // Buffered channel for cleanup items + cleanupTimers: make(map[string]*time.Timer), circuitBreaker: &RefreshCircuitBreaker{ config: RefreshCircuitBreakerConfig{ MaxFailures: 3, @@ -174,6 +189,10 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref rc.wg.Add(1) go rc.cleanupRoutine() + // Start delayed cleanup processor (single goroutine processes all cleanup timers) + rc.wg.Add(1) + go rc.processDelayedCleanups() + return rc } @@ -313,16 +332,9 @@ func (rc *RefreshCoordinator) executeRefreshAsync( // Signal completion to all waiters close(operation.done) - // Clean up operation after a configurable delay to allow waiters to read result - go func() { - if rc.config.DeduplicationCleanupDelay > 0 { - time.Sleep(rc.config.DeduplicationCleanupDelay) - } - rc.refreshMutex.Lock() - delete(rc.inFlightRefreshes, tokenHash) - rc.refreshMutex.Unlock() - atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1) - }() + // Schedule delayed cleanup using timer instead of spawning a goroutine + // This prevents goroutine explosion under high load + rc.scheduleDelayedCleanup(tokenHash) }() // Create timeout context @@ -369,6 +381,65 @@ func (rc *RefreshCoordinator) executeRefreshAsync( } } +// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning a goroutine +// This prevents goroutine explosion under high load (500+ req/sec) +func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) { + delay := rc.config.DeduplicationCleanupDelay + if delay <= 0 { + // Immediate cleanup + rc.performCleanup(tokenHash) + return + } + + // Use time.AfterFunc which is more efficient than spawning a goroutine with Sleep + // time.AfterFunc uses the runtime's timer heap which is much more efficient + rc.cleanupTimerMu.Lock() + // Cancel any existing timer for this hash (shouldn't happen, but just in case) + if existingTimer, exists := rc.cleanupTimers[tokenHash]; exists { + existingTimer.Stop() + } + rc.cleanupTimers[tokenHash] = time.AfterFunc(delay, func() { + rc.performCleanup(tokenHash) + // Remove timer from map + rc.cleanupTimerMu.Lock() + delete(rc.cleanupTimers, tokenHash) + rc.cleanupTimerMu.Unlock() + }) + rc.cleanupTimerMu.Unlock() +} + +// performCleanup removes the operation from the in-flight map +func (rc *RefreshCoordinator) performCleanup(tokenHash string) { + rc.refreshMutex.Lock() + delete(rc.inFlightRefreshes, tokenHash) + rc.refreshMutex.Unlock() + atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1) +} + +// processDelayedCleanups processes delayed cleanup requests from the queue +// This is a single goroutine that handles all delayed cleanups +func (rc *RefreshCoordinator) processDelayedCleanups() { + defer rc.wg.Done() + + for { + select { + case item := <-rc.delayedCleanupQueue: + // Wait until cleanup time + waitDuration := time.Until(item.cleanupAt) + if waitDuration > 0 { + select { + case <-time.After(waitDuration): + case <-rc.stopChan: + return + } + } + rc.performCleanup(item.tokenHash) + case <-rc.stopChan: + return + } + } +} + // isInCooldown checks if a session is in cooldown after recording an attempt func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool { rc.attemptsMutex.Lock() @@ -516,6 +587,15 @@ func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} { // Shutdown gracefully shuts down the coordinator func (rc *RefreshCoordinator) Shutdown() { close(rc.stopChan) + + // Cancel all pending cleanup timers + rc.cleanupTimerMu.Lock() + for _, timer := range rc.cleanupTimers { + timer.Stop() + } + rc.cleanupTimers = make(map[string]*time.Timer) + rc.cleanupTimerMu.Unlock() + rc.wg.Wait() } diff --git a/refresh_coordinator_test.go b/refresh_coordinator_test.go index a108847..ea26baf 100644 --- a/refresh_coordinator_test.go +++ b/refresh_coordinator_test.go @@ -671,3 +671,98 @@ func TestCleanupRoutine(t *testing.T) { t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount) } } + +// TestNoGoroutineExplosionWithTimers verifies that timer-based cleanup doesn't cause goroutine explosion +// This was the original issue: spawning a goroutine per refresh to sleep and cleanup +func TestNoGoroutineExplosionWithTimers(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.DeduplicationCleanupDelay = 100 * time.Millisecond // Non-zero delay + config.MaxConcurrentRefreshes = 100 // Allow many concurrent + config.MaxRefreshAttempts = 10000 // Don't rate limit + + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Record initial goroutines (allow settling time) + time.Sleep(50 * time.Millisecond) + runtime.GC() + initialGoroutines := runtime.NumGoroutine() + t.Logf("Initial goroutines: %d", initialGoroutines) + + // Submit many refresh operations rapidly + const numRefreshes = 500 + var wg sync.WaitGroup + wg.Add(numRefreshes) + + refreshFunc := func() (*TokenResponse, error) { + return &TokenResponse{AccessToken: "token"}, nil + } + + for i := 0; i < numRefreshes; i++ { + go func(id int) { + defer wg.Done() + ctx := context.Background() + _, _ = coordinator.CoordinateRefresh( + ctx, + fmt.Sprintf("session_%d", id), + fmt.Sprintf("token_%d", id), + refreshFunc, + ) + }(i) + } + + wg.Wait() + + // Measure goroutines immediately after all operations complete + // With the old approach, we'd have ~500 sleeping goroutines + // With the new timer approach, we should have much fewer + currentGoroutines := runtime.NumGoroutine() + t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines) + + // Check timer count + coordinator.cleanupTimerMu.Lock() + timerCount := len(coordinator.cleanupTimers) + coordinator.cleanupTimerMu.Unlock() + t.Logf("Active cleanup timers: %d", timerCount) + + // With timer-based cleanup, goroutine increase should be minimal + // Timers don't create goroutines - they use the runtime timer heap + goroutineIncrease := currentGoroutines - initialGoroutines + + // Allow for some goroutine overhead (test framework, etc) + // With the old approach, we'd see ~500 goroutines + // With the new approach, we should see <50 (much smaller) + maxAcceptableIncrease := 100 // Very generous limit + + if goroutineIncrease > maxAcceptableIncrease { + t.Errorf("Goroutine explosion detected: started with %d, now have %d (increase of %d)", + initialGoroutines, currentGoroutines, goroutineIncrease) + } + + // Wait for timers to fire and cleanup + time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond) + + // Verify timers were cleaned up + coordinator.cleanupTimerMu.Lock() + remainingTimers := len(coordinator.cleanupTimers) + coordinator.cleanupTimerMu.Unlock() + + // Most timers should have fired and been removed + if remainingTimers > 10 { + t.Errorf("Too many cleanup timers remaining: %d", remainingTimers) + } + + // Verify goroutines returned to near initial + runtime.GC() + time.Sleep(50 * time.Millisecond) + finalGoroutines := runtime.NumGoroutine() + t.Logf("Final goroutines: %d", finalGoroutines) + + // Should be close to initial (within tolerance) + finalIncrease := finalGoroutines - initialGoroutines + if finalIncrease > 20 { + t.Errorf("Goroutine leak detected: started with %d, ended with %d (increase of %d)", + initialGoroutines, finalGoroutines, finalIncrease) + } +} diff --git a/session_chunk_manager.go b/session_chunk_manager.go index 6d08bea..ff8d6bc 100644 --- a/session_chunk_manager.go +++ b/session_chunk_manager.go @@ -5,7 +5,6 @@ import ( "context" "encoding/base64" "fmt" - "runtime" "strings" "sync" "sync/atomic" @@ -196,26 +195,23 @@ func (cm *ChunkManager) performPeriodicCleanup() { cm.CleanupExpiredSessions() - // Force garbage collection if memory usage is high - var m runtime.MemStats - runtime.ReadMemStats(&m) - + // Track memory stats for monitoring but DO NOT force GC + // Forced GC causes significant CPU spikes every cleanup interval + // Let the Go runtime handle GC scheduling efficiently currentSessions := atomic.LoadInt64(&cm.peakSessions) allocatedBytes := atomic.LoadInt64(&cm.bytesAllocated) - if allocatedBytes > 10*1024*1024 || currentSessions > int64(cm.maxSessions/2) { - runtime.GC() - if cm.logger != nil { - cm.logger.Debugf("Forced GC: sessions=%d, allocated=%d bytes", - currentSessions, allocatedBytes) - } - } - duration := time.Since(startTime) atomic.AddInt64(&cm.cleanupCount, 1) - if cm.logger != nil && duration > 100*time.Millisecond { - cm.logger.Debugf("Chunk manager cleanup took %v", duration) + if cm.logger != nil { + if duration > 100*time.Millisecond { + cm.logger.Debugf("Chunk manager cleanup took %v (sessions=%d, allocated=%d bytes)", + duration, currentSessions, allocatedBytes) + } else if duration > 10*time.Millisecond { + cm.logger.Debugf("Chunk manager cleanup: sessions=%d, allocated=%d bytes, duration=%v", + currentSessions, allocatedBytes, duration) + } } } @@ -1161,6 +1157,7 @@ func (cm *ChunkManager) CleanupExpiredSessions(force ...bool) { } // enforceSessionLimit removes oldest sessions when limit is exceeded +// Uses partial sort (O(n) for finding k smallest) instead of full sort (O(n log n)) func (cm *ChunkManager) enforceSessionLimit() { currentLocal := len(cm.sessionMap) currentGlobal := atomic.LoadInt64(&globalSessionCount) @@ -1185,37 +1182,20 @@ func (cm *ChunkManager) enforceSessionLimit() { return } - // Find oldest sessions to remove - type sessionAge struct { - key string - lastUsed time.Time - } - - sessions := make([]sessionAge, 0, len(cm.sessionMap)) - for key, entry := range cm.sessionMap { - sessions = append(sessions, sessionAge{key: key, lastUsed: entry.LastUsed}) - } - - // Sort by last used time (oldest first) - for i := 0; i < len(sessions)-1; i++ { - for j := i + 1; j < len(sessions); j++ { - if sessions[i].lastUsed.After(sessions[j].lastUsed) { - sessions[i], sessions[j] = sessions[j], sessions[i] - } - } - } - - // Remove excess sessions and track memory - CRITICAL FIX: More aggressive + // Calculate how many sessions to remove excessCount := currentLocal - targetCapacity - if excessCount < 0 { - excessCount = 0 + if excessCount <= 0 { + return } + // Use partial selection instead of full sort for better performance + // For finding k oldest sessions, we only need O(n) operations + keysToRemove := cm.findOldestSessions(excessCount) + totalBytesFreed := int64(0) removedCount := int64(0) - for i := 0; i < excessCount && i < len(sessions); i++ { - key := sessions[i].key + for _, key := range keysToRemove { if entry, exists := cm.sessionMap[key]; exists { totalBytesFreed += entry.SizeEstimate atomic.AddInt64(&cm.bytesAllocated, -entry.SizeEstimate) @@ -1230,7 +1210,56 @@ func (cm *ChunkManager) enforceSessionLimit() { } cm.logger.Infof("Enforced session limit: removed %d excess sessions, freed %d bytes", - excessCount, totalBytesFreed) + len(keysToRemove), totalBytesFreed) +} + +// findOldestSessions returns keys of the k oldest sessions efficiently +// Uses a simple approach: find the kth oldest timestamp, then collect all older entries +func (cm *ChunkManager) findOldestSessions(k int) []string { + if k <= 0 || len(cm.sessionMap) == 0 { + return nil + } + + if k >= len(cm.sessionMap) { + // Remove all sessions + keys := make([]string, 0, len(cm.sessionMap)) + for key := range cm.sessionMap { + keys = append(keys, key) + } + return keys + } + + // Collect all timestamps with keys + type sessionAge struct { + key string + lastUsed time.Time + } + + sessions := make([]sessionAge, 0, len(cm.sessionMap)) + for key, entry := range cm.sessionMap { + sessions = append(sessions, sessionAge{key: key, lastUsed: entry.LastUsed}) + } + + // Partial sort: get the k smallest elements using selection + // This is O(n*k) which is better than O(n log n) when k << n + for i := 0; i < k; i++ { + minIdx := i + for j := i + 1; j < len(sessions); j++ { + if sessions[j].lastUsed.Before(sessions[minIdx].lastUsed) { + minIdx = j + } + } + if minIdx != i { + sessions[i], sessions[minIdx] = sessions[minIdx], sessions[i] + } + } + + // Return the k oldest + result := make([]string, k) + for i := 0; i < k; i++ { + result[i] = sessions[i].key + } + return result } // CanCreateSession checks if a new session can be created within limits @@ -1292,29 +1321,12 @@ func (cm *ChunkManager) EmergencyCleanup() { // If still over 80% capacity, remove oldest sessions more aggressively targetCapacity := int(float64(cm.maxSessions) * 0.8) if len(cm.sessionMap) > targetCapacity { - type sessionAge struct { - key string - lastUsed time.Time - } - - sessions := make([]sessionAge, 0, len(cm.sessionMap)) - for key, entry := range cm.sessionMap { - sessions = append(sessions, sessionAge{key: key, lastUsed: entry.LastUsed}) - } - - // Sort by last used time (oldest first) - for i := 0; i < len(sessions)-1; i++ { - for j := i + 1; j < len(sessions); j++ { - if sessions[i].lastUsed.After(sessions[j].lastUsed) { - sessions[i], sessions[j] = sessions[j], sessions[i] - } - } - } - - // Remove sessions until we reach target capacity excessCount := len(cm.sessionMap) - targetCapacity - for i := 0; i < excessCount && i < len(sessions); i++ { - key := sessions[i].key + + // Use efficient partial sort to find oldest sessions + keysToRemove := cm.findOldestSessions(excessCount) + + for _, key := range keysToRemove { if entry, exists := cm.sessionMap[key]; exists { atomic.AddInt64(&cm.bytesAllocated, -entry.SizeEstimate) } @@ -1327,11 +1339,9 @@ func (cm *ChunkManager) EmergencyCleanup() { cm.logger.Infof("Emergency cleanup completed: removed %d sessions, %d remaining", removed, len(cm.sessionMap)) - // Log memory stats after emergency cleanup - var m runtime.MemStats - runtime.ReadMemStats(&m) - cm.logger.Infof("Memory after emergency cleanup - Heap: %.1fMB, Sessions: %d, Tracked bytes: %d", - float64(m.HeapAlloc)/(1024*1024), len(cm.sessionMap), atomic.LoadInt64(&cm.bytesAllocated)) + // Log memory stats after emergency cleanup (read only, no forced GC) + cm.logger.Infof("Sessions after emergency cleanup: %d, Tracked bytes: %d", + len(cm.sessionMap), atomic.LoadInt64(&cm.bytesAllocated)) } // GetSessionCount returns the current number of active sessions (for monitoring) diff --git a/settings.go b/settings.go index be82479..a80b5a6 100644 --- a/settings.go +++ b/settings.go @@ -89,6 +89,91 @@ type Config struct { // Recommended: true for multi-replica deployments DisableReplayDetection bool `json:"disableReplayDetection,omitempty"` SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"` + + // DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591) + // When enabled, the middleware will automatically register as a client with + // the OIDC provider if ClientID/ClientSecret are not provided. + DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"` +} + +// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591) +type DynamicClientRegistrationConfig struct { + // Enabled enables automatic client registration with the OIDC provider + Enabled bool `json:"enabled"` + + // InitialAccessToken is an optional bearer token for protected registration endpoints + // Some providers require this token to authorize new client registrations + InitialAccessToken string `json:"initialAccessToken,omitempty"` + + // RegistrationEndpoint overrides the endpoint discovered from provider metadata + // If empty, uses the registration_endpoint from .well-known/openid-configuration + RegistrationEndpoint string `json:"registrationEndpoint,omitempty"` + + // ClientMetadata contains the client metadata to register + ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"` + + // PersistCredentials determines whether to save registered credentials to a file + // This allows reusing the same client_id/client_secret across restarts + PersistCredentials bool `json:"persistCredentials"` + + // CredentialsFile is the path to store/load registered client credentials + // Defaults to "/tmp/oidc-client-credentials.json" if not specified + CredentialsFile string `json:"credentialsFile,omitempty"` +} + +// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591) +type ClientRegistrationMetadata struct { + // RedirectURIs is REQUIRED - array of redirect URIs for authorization + RedirectURIs []string `json:"redirect_uris"` + + // ResponseTypes specifies OAuth 2.0 response types (default: ["code"]) + ResponseTypes []string `json:"response_types,omitempty"` + + // GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"]) + GrantTypes []string `json:"grant_types,omitempty"` + + // ApplicationType is either "web" (default) or "native" + ApplicationType string `json:"application_type,omitempty"` + + // Contacts is an array of email addresses for responsible parties + Contacts []string `json:"contacts,omitempty"` + + // ClientName is a human-readable name for the client + ClientName string `json:"client_name,omitempty"` + + // LogoURI is a URL pointing to a logo for the client + LogoURI string `json:"logo_uri,omitempty"` + + // ClientURI is a URL of the home page of the client + ClientURI string `json:"client_uri,omitempty"` + + // PolicyURI is a URL pointing to the client's privacy policy + PolicyURI string `json:"policy_uri,omitempty"` + + // TOSURI is a URL pointing to the client's terms of service + TOSURI string `json:"tos_uri,omitempty"` + + // JWKSURI is a URL for the client's JSON Web Key Set + JWKSURI string `json:"jwks_uri,omitempty"` + + // SubjectType is "pairwise" or "public" (provider-specific) + SubjectType string `json:"subject_type,omitempty"` + + // TokenEndpointAuthMethod specifies how the client authenticates at token endpoint + // Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none" + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // DefaultMaxAge is the default maximum authentication age in seconds + DefaultMaxAge int `json:"default_max_age,omitempty"` + + // RequireAuthTime specifies whether auth_time claim is required in ID token + RequireAuthTime bool `json:"require_auth_time,omitempty"` + + // DefaultACRValues specifies default ACR values + DefaultACRValues []string `json:"default_acr_values,omitempty"` + + // Scope is a space-separated list of scopes (alternative to config.Scopes) + Scope string `json:"scope,omitempty"` } // SecurityHeadersConfig configures security headers for the plugin diff --git a/sharded_cache.go b/sharded_cache.go new file mode 100644 index 0000000..a533114 --- /dev/null +++ b/sharded_cache.go @@ -0,0 +1,207 @@ +package traefikoidc + +import ( + "hash/fnv" + "sync" + "time" +) + +// ShardedCache provides a thread-safe cache with sharded locks to reduce contention. +// Instead of a single global mutex, it distributes entries across multiple shards, +// each with its own mutex. This dramatically reduces lock contention under high load. +type ShardedCache struct { + shards []*cacheShard + numShards uint32 + maxPerShard int +} + +// cacheShard represents a single shard with its own mutex and data map. +type cacheShard struct { + mu sync.RWMutex + items map[string]*shardedCacheItem +} + +// shardedCacheItem represents an item in the sharded cache with expiration. +type shardedCacheItem struct { + value interface{} + expiresAt time.Time +} + +// NewShardedCache creates a new sharded cache with the specified number of shards. +// More shards = less contention but more memory overhead. +// Recommended: 32-256 shards depending on expected concurrency. +func NewShardedCache(numShards int, maxSize int) *ShardedCache { + if numShards <= 0 { + numShards = 64 // Default to 64 shards + } + if maxSize <= 0 { + maxSize = 10000 // Default max size + } + + shards := make([]*cacheShard, numShards) + maxPerShard := maxSize / numShards + if maxPerShard < 100 { + maxPerShard = 100 // Minimum 100 per shard + } + + for i := 0; i < numShards; i++ { + shards[i] = &cacheShard{ + items: make(map[string]*shardedCacheItem), + } + } + + return &ShardedCache{ + shards: shards, + numShards: uint32(numShards), + maxPerShard: maxPerShard, + } +} + +// getShard returns the shard for a given key using FNV-1a hash. +// FNV-1a is fast and provides good distribution. +func (c *ShardedCache) getShard(key string) *cacheShard { + h := fnv.New32a() + h.Write([]byte(key)) + return c.shards[h.Sum32()%c.numShards] +} + +// Get retrieves an item from the cache. +// Returns the value and true if found and not expired, nil and false otherwise. +func (c *ShardedCache) Get(key string) (interface{}, bool) { + shard := c.getShard(key) + shard.mu.RLock() + item, exists := shard.items[key] + shard.mu.RUnlock() + + if !exists { + return nil, false + } + + // Check expiration + if !item.expiresAt.IsZero() && time.Now().After(item.expiresAt) { + // Item expired - remove it lazily + c.Delete(key) + return nil, false + } + + return item.value, true +} + +// Set adds or updates an item in the cache with a TTL. +// If ttl is 0 or negative, the item never expires. +func (c *ShardedCache) Set(key string, value interface{}, ttl time.Duration) { + shard := c.getShard(key) + + var expiresAt time.Time + if ttl > 0 { + expiresAt = time.Now().Add(ttl) + } + + shard.mu.Lock() + // Check if we need to evict items + if len(shard.items) >= c.maxPerShard { + // Simple eviction: remove expired items first, then oldest + c.evictFromShardLocked(shard) + } + + shard.items[key] = &shardedCacheItem{ + value: value, + expiresAt: expiresAt, + } + shard.mu.Unlock() +} + +// Delete removes an item from the cache. +func (c *ShardedCache) Delete(key string) { + shard := c.getShard(key) + shard.mu.Lock() + delete(shard.items, key) + shard.mu.Unlock() +} + +// Exists checks if a key exists in the cache and is not expired. +func (c *ShardedCache) Exists(key string) bool { + _, exists := c.Get(key) + return exists +} + +// evictFromShardLocked removes expired items from a shard. +// Must be called with shard.mu held. +func (c *ShardedCache) evictFromShardLocked(shard *cacheShard) { + now := time.Now() + evicted := 0 + maxEvict := len(shard.items) / 4 // Evict up to 25% of items + if maxEvict < 10 { + maxEvict = 10 + } + + // First pass: remove expired items + for key, item := range shard.items { + if !item.expiresAt.IsZero() && now.After(item.expiresAt) { + delete(shard.items, key) + evicted++ + if evicted >= maxEvict { + return + } + } + } + + // If still over capacity, remove some items (FIFO approximation via map iteration) + // This is an approximation since Go maps don't maintain insertion order + remaining := len(shard.items) - c.maxPerShard + 10 // Leave some headroom + if remaining > 0 { + for key := range shard.items { + delete(shard.items, key) + remaining-- + if remaining <= 0 { + break + } + } + } +} + +// Cleanup removes all expired items from all shards. +// Call this periodically to prevent memory growth. +func (c *ShardedCache) Cleanup() { + now := time.Now() + for _, shard := range c.shards { + shard.mu.Lock() + for key, item := range shard.items { + if !item.expiresAt.IsZero() && now.After(item.expiresAt) { + delete(shard.items, key) + } + } + shard.mu.Unlock() + } +} + +// Size returns the total number of items across all shards. +func (c *ShardedCache) Size() int { + total := 0 + for _, shard := range c.shards { + shard.mu.RLock() + total += len(shard.items) + shard.mu.RUnlock() + } + return total +} + +// Clear removes all items from all shards. +func (c *ShardedCache) Clear() { + for _, shard := range c.shards { + shard.mu.Lock() + shard.items = make(map[string]*shardedCacheItem) + shard.mu.Unlock() + } +} + +// ShardStats returns statistics about each shard for debugging/monitoring. +func (c *ShardedCache) ShardStats() []int { + stats := make([]int, len(c.shards)) + for i, shard := range c.shards { + shard.mu.RLock() + stats[i] = len(shard.items) + shard.mu.RUnlock() + } + return stats +} diff --git a/sharded_cache_test.go b/sharded_cache_test.go new file mode 100644 index 0000000..6de8419 --- /dev/null +++ b/sharded_cache_test.go @@ -0,0 +1,413 @@ +package traefikoidc + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestShardedCacheBasicOperations(t *testing.T) { + t.Run("SetAndGet", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + cache.Set("key1", "value1", 5*time.Minute) + cache.Set("key2", 42, 5*time.Minute) + cache.Set("key3", true, 5*time.Minute) + + val1, ok := cache.Get("key1") + if !ok || val1 != "value1" { + t.Errorf("Expected 'value1', got %v, ok=%v", val1, ok) + } + + val2, ok := cache.Get("key2") + if !ok || val2 != 42 { + t.Errorf("Expected 42, got %v, ok=%v", val2, ok) + } + + val3, ok := cache.Get("key3") + if !ok || val3 != true { + t.Errorf("Expected true, got %v, ok=%v", val3, ok) + } + }) + + t.Run("GetNonExistent", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + val, ok := cache.Get("nonexistent") + if ok || val != nil { + t.Errorf("Expected nil/false for nonexistent key, got %v/%v", val, ok) + } + }) + + t.Run("Delete", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + cache.Set("key1", "value1", 5*time.Minute) + cache.Delete("key1") + + val, ok := cache.Get("key1") + if ok || val != nil { + t.Errorf("Expected nil/false after delete, got %v/%v", val, ok) + } + }) + + t.Run("Exists", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + cache.Set("key1", "value1", 5*time.Minute) + + if !cache.Exists("key1") { + t.Error("Expected Exists to return true for existing key") + } + + if cache.Exists("nonexistent") { + t.Error("Expected Exists to return false for nonexistent key") + } + }) + + t.Run("Size", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + if cache.Size() != 0 { + t.Errorf("Expected size 0, got %d", cache.Size()) + } + + for i := 0; i < 100; i++ { + cache.Set(fmt.Sprintf("key%d", i), i, 5*time.Minute) + } + + if cache.Size() != 100 { + t.Errorf("Expected size 100, got %d", cache.Size()) + } + }) + + t.Run("Clear", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + for i := 0; i < 100; i++ { + cache.Set(fmt.Sprintf("key%d", i), i, 5*time.Minute) + } + + cache.Clear() + + if cache.Size() != 0 { + t.Errorf("Expected size 0 after clear, got %d", cache.Size()) + } + }) +} + +func TestShardedCacheExpiration(t *testing.T) { + t.Run("ItemExpires", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + cache.Set("key1", "value1", 50*time.Millisecond) + + // Should exist immediately + if !cache.Exists("key1") { + t.Error("Item should exist immediately after set") + } + + // Wait for expiration + time.Sleep(100 * time.Millisecond) + + // Should be expired now + if cache.Exists("key1") { + t.Error("Item should have expired") + } + }) + + t.Run("CleanupRemovesExpired", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + // Add items with short TTL + for i := 0; i < 50; i++ { + cache.Set(fmt.Sprintf("expired%d", i), i, 10*time.Millisecond) + } + + // Add items with long TTL + for i := 0; i < 50; i++ { + cache.Set(fmt.Sprintf("valid%d", i), i, 5*time.Minute) + } + + // Wait for short-TTL items to expire + time.Sleep(50 * time.Millisecond) + + // Run cleanup + cache.Cleanup() + + // Should have only valid items + // Note: Size still includes expired items until Get/Cleanup removes them + // So we check by accessing items + for i := 0; i < 50; i++ { + if cache.Exists(fmt.Sprintf("expired%d", i)) { + t.Errorf("Expired item %d should not exist after cleanup", i) + } + } + + for i := 0; i < 50; i++ { + if !cache.Exists(fmt.Sprintf("valid%d", i)) { + t.Errorf("Valid item %d should still exist after cleanup", i) + } + } + }) + + t.Run("ZeroTTLNeverExpires", func(t *testing.T) { + cache := NewShardedCache(16, 1000) + + cache.Set("permanent", "value", 0) + + time.Sleep(10 * time.Millisecond) + + if !cache.Exists("permanent") { + t.Error("Item with 0 TTL should never expire") + } + }) +} + +func TestShardedCacheConcurrency(t *testing.T) { + t.Run("ConcurrentSetGet", func(t *testing.T) { + cache := NewShardedCache(64, 10000) + const numGoroutines = 100 + const numOperations = 1000 + + var wg sync.WaitGroup + var errors int32 + + // Concurrent writers + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("key-%d-%d", id, j) + cache.Set(key, j, 5*time.Minute) + } + }(i) + } + + // Concurrent readers + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("key-%d-%d", id, j) + cache.Get(key) + } + }(i) + } + + wg.Wait() + + if atomic.LoadInt32(&errors) > 0 { + t.Errorf("Encountered %d errors during concurrent access", errors) + } + }) + + t.Run("ConcurrentMixedOperations", func(t *testing.T) { + cache := NewShardedCache(64, 10000) + const numGoroutines = 50 + const numOperations = 500 + + var wg sync.WaitGroup + + // Mix of operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("key-%d", j%100) // Overlapping keys + switch j % 4 { + case 0: + cache.Set(key, j, 5*time.Minute) + case 1: + cache.Get(key) + case 2: + cache.Exists(key) + case 3: + cache.Delete(key) + } + } + }(i) + } + + wg.Wait() + }) + + t.Run("NoConcurrentPanics", func(t *testing.T) { + cache := NewShardedCache(32, 5000) + const numGoroutines = 100 + + var wg sync.WaitGroup + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + t.Errorf("Panic in goroutine %d: %v", id, r) + } + }() + + for j := 0; j < 100; j++ { + cache.Set(fmt.Sprintf("k%d", j), j, time.Millisecond) + cache.Get(fmt.Sprintf("k%d", j)) + cache.Cleanup() + } + }(i) + } + + wg.Wait() + }) +} + +func TestShardedCacheEviction(t *testing.T) { + t.Run("EvictsWhenFull", func(t *testing.T) { + // Small cache to trigger eviction - 4 shards with max 100 per shard minimum + // With our implementation, maxPerShard defaults to at least 100 + cache := NewShardedCache(4, 100) + + // Fill well beyond capacity to trigger eviction + for i := 0; i < 600; i++ { + cache.Set(fmt.Sprintf("key%d", i), i, 5*time.Minute) + } + + // Should have evicted some items - eviction happens when shard reaches maxPerShard + size := cache.Size() + // With 4 shards and 100 per shard minimum, max should be ~400 + // We added 600, so some should be evicted + if size >= 600 { + t.Errorf("Expected eviction to reduce size below 600, got %d", size) + } + t.Logf("Cache size after adding 600 items: %d", size) + }) + + t.Run("EvictsExpiredFirst", func(t *testing.T) { + cache := NewShardedCache(4, 100) + + // Add expired items first + for i := 0; i < 50; i++ { + cache.Set(fmt.Sprintf("expired%d", i), i, 1*time.Millisecond) + } + + time.Sleep(10 * time.Millisecond) // Let them expire + + // Add valid items + for i := 0; i < 100; i++ { + cache.Set(fmt.Sprintf("valid%d", i), i, 5*time.Minute) + } + + // Valid items should mostly still exist + validCount := 0 + for i := 0; i < 100; i++ { + if cache.Exists(fmt.Sprintf("valid%d", i)) { + validCount++ + } + } + + // Should have most valid items (at least 80%) + if validCount < 80 { + t.Errorf("Expected at least 80 valid items, got %d", validCount) + } + }) +} + +func TestShardedCacheShardDistribution(t *testing.T) { + t.Run("EvenDistribution", func(t *testing.T) { + cache := NewShardedCache(16, 16000) + + // Add many items + for i := 0; i < 10000; i++ { + cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute) + } + + stats := cache.ShardStats() + + // Check for reasonable distribution (no shard should have > 2x average) + average := 10000 / 16 + for i, count := range stats { + if count > average*3 || count < average/3 { + t.Errorf("Shard %d has uneven distribution: %d items (expected ~%d)", i, count, average) + } + } + }) +} + +// BenchmarkShardedCache benchmarks the sharded cache operations +func BenchmarkShardedCache(b *testing.B) { + b.Run("Set", func(b *testing.B) { + cache := NewShardedCache(64, 100000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute) + } + }) + + b.Run("Get", func(b *testing.B) { + cache := NewShardedCache(64, 100000) + for i := 0; i < 10000; i++ { + cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache.Get(fmt.Sprintf("key-%d", i%10000)) + } + }) + + b.Run("ParallelSetGet", func(b *testing.B) { + cache := NewShardedCache(64, 100000) + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("key-%d", i) + cache.Set(key, i, 5*time.Minute) + cache.Get(key) + i++ + } + }) + }) +} + +// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach +func BenchmarkShardedVsGlobalMutex(b *testing.B) { + b.Run("ShardedCache64", func(b *testing.B) { + cache := NewShardedCache(64, 100000) + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("jti-%d", i%10000) + if !cache.Exists(key) { + cache.Set(key, true, 5*time.Minute) + } + i++ + } + }) + }) + + b.Run("GlobalMutexCache", func(b *testing.B) { + var mu sync.RWMutex + data := make(map[string]bool) + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("jti-%d", i%10000) + + mu.RLock() + _, exists := data[key] + mu.RUnlock() + + if !exists { + mu.Lock() + data[key] = true + mu.Unlock() + } + i++ + } + }) + }) +} diff --git a/singleton_resources.go b/singleton_resources.go index ce3a3b2..ad5143d 100644 --- a/singleton_resources.go +++ b/singleton_resources.go @@ -345,6 +345,10 @@ type GoroutinePool struct { shutdownChan chan struct{} logger *Logger started int32 + + // Condition variable for efficient Wait() without busy-polling + taskCond *sync.Cond + pendingTasks int64 // atomic counter for pending tasks } // NewGoroutinePool creates a new goroutine pool with the specified max workers @@ -354,6 +358,8 @@ func NewGoroutinePool(maxWorkers int, logger *Logger) *GoroutinePool { taskQueue: make(chan func(), maxWorkers*2), // Buffer for queuing shutdownChan: make(chan struct{}), logger: logger, + taskCond: sync.NewCond(&sync.Mutex{}), + pendingTasks: 0, } // Start workers @@ -390,6 +396,14 @@ func (p *GoroutinePool) worker(id int) { }() task() }() + + // Signal that task is complete - decrement pending count and notify waiters + newCount := atomic.AddInt64(&p.pendingTasks, -1) + if newCount == 0 { + p.taskCond.L.Lock() + p.taskCond.Broadcast() // Wake up all waiters when queue is empty + p.taskCond.L.Unlock() + } } case <-p.shutdownChan: if p.logger != nil { @@ -406,10 +420,15 @@ func (p *GoroutinePool) Submit(task func()) error { return fmt.Errorf("pool is shutdown") } + // Increment pending task count BEFORE queuing to avoid race with Wait() + atomic.AddInt64(&p.pendingTasks, 1) + select { case p.taskQueue <- task: return nil case <-p.shutdownChan: + // Decrement since task won't be processed + atomic.AddInt64(&p.pendingTasks, -1) return fmt.Errorf("pool is shutting down") default: // Queue is full, try with a small timeout @@ -417,21 +436,53 @@ func (p *GoroutinePool) Submit(task func()) error { case p.taskQueue <- task: return nil case <-time.After(100 * time.Millisecond): + // Decrement since task won't be processed + atomic.AddInt64(&p.pendingTasks, -1) return fmt.Errorf("task queue is full") case <-p.shutdownChan: + // Decrement since task won't be processed + atomic.AddInt64(&p.pendingTasks, -1) return fmt.Errorf("pool is shutting down") } } } -// Wait waits for all submitted tasks to complete +// Wait waits for all submitted tasks to complete using condition variable +// This is efficient and does not busy-poll, avoiding CPU spikes func (p *GoroutinePool) Wait() { - // Drain the task queue - for len(p.taskQueue) > 0 { - time.Sleep(10 * time.Millisecond) + p.taskCond.L.Lock() + defer p.taskCond.L.Unlock() + + // Wait until all pending tasks are complete + // Uses condition variable to sleep efficiently instead of busy-polling + for atomic.LoadInt64(&p.pendingTasks) > 0 { + p.taskCond.Wait() // Efficiently blocks until signaled } } +// WaitWithTimeout waits for all submitted tasks to complete with a timeout +// Returns true if all tasks completed, false if timeout occurred +func (p *GoroutinePool) WaitWithTimeout(timeout time.Duration) bool { + done := make(chan struct{}) + + go func() { + p.Wait() + close(done) + }() + + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + +// PendingTasks returns the number of tasks currently pending (queued or in-progress) +func (p *GoroutinePool) PendingTasks() int64 { + return atomic.LoadInt64(&p.pendingTasks) +} + // Shutdown gracefully shuts down the pool func (p *GoroutinePool) Shutdown(ctx context.Context) error { var err error diff --git a/singleton_resources_test.go b/singleton_resources_test.go index 4ce8c8d..af63ace 100644 --- a/singleton_resources_test.go +++ b/singleton_resources_test.go @@ -505,6 +505,285 @@ func TestBackwardCompatibility(t *testing.T) { }) } +// TestGoroutinePoolConditionVariable tests the condition variable-based Wait implementation +func TestGoroutinePoolConditionVariable(t *testing.T) { + t.Run("WaitDoesNotBusyPoll", func(t *testing.T) { + // This test verifies that Wait() uses condition variable instead of busy-polling + pool := NewGoroutinePool(2, nil) + defer pool.Shutdown(context.Background()) + + // Submit a slow task + var taskStarted, taskFinished int32 + pool.Submit(func() { + atomic.StoreInt32(&taskStarted, 1) + time.Sleep(100 * time.Millisecond) + atomic.StoreInt32(&taskFinished, 1) + }) + + // Give task time to start + time.Sleep(10 * time.Millisecond) + + // Measure CPU-time before Wait + startCPU := time.Now() + + // Wait should block efficiently without consuming CPU + pool.Wait() + + elapsed := time.Since(startCPU) + + // Verify task completed + if atomic.LoadInt32(&taskFinished) != 1 { + t.Error("Task should have finished") + } + + // Wait should have taken ~90ms (task was already running for ~10ms) + // If it was busy-polling, we would see much higher CPU usage + // This is a sanity check - the real proof is in profiling + if elapsed < 50*time.Millisecond { + t.Errorf("Wait returned too quickly: %v", elapsed) + } + }) + + t.Run("WaitReturnsImmediatelyWhenEmpty", func(t *testing.T) { + pool := NewGoroutinePool(2, nil) + defer pool.Shutdown(context.Background()) + + // Wait on empty pool should return immediately + start := time.Now() + pool.Wait() + elapsed := time.Since(start) + + // Should return almost immediately + if elapsed > 10*time.Millisecond { + t.Errorf("Wait on empty pool took too long: %v", elapsed) + } + }) + + t.Run("ConcurrentSubmitAndWait", func(t *testing.T) { + pool := NewGoroutinePool(4, nil) + defer pool.Shutdown(context.Background()) + + var completed int32 + const numTasks = 100 + + // Submit tasks concurrently + var wg sync.WaitGroup + for i := 0; i < numTasks; i++ { + wg.Add(1) + go func() { + defer wg.Done() + pool.Submit(func() { + time.Sleep(1 * time.Millisecond) + atomic.AddInt32(&completed, 1) + }) + }() + } + + wg.Wait() // Wait for all submissions + + // Wait for all tasks to complete + pool.Wait() + + if atomic.LoadInt32(&completed) != numTasks { + t.Errorf("Expected %d tasks completed, got %d", numTasks, completed) + } + }) + + t.Run("WaitWithTimeoutSuccess", func(t *testing.T) { + pool := NewGoroutinePool(2, nil) + defer pool.Shutdown(context.Background()) + + pool.Submit(func() { + time.Sleep(50 * time.Millisecond) + }) + + // Should complete within timeout + success := pool.WaitWithTimeout(1 * time.Second) + if !success { + t.Error("WaitWithTimeout should have succeeded") + } + }) + + t.Run("WaitWithTimeoutExpired", func(t *testing.T) { + pool := NewGoroutinePool(1, nil) + defer pool.Shutdown(context.Background()) + + pool.Submit(func() { + time.Sleep(500 * time.Millisecond) + }) + + // Should timeout + success := pool.WaitWithTimeout(50 * time.Millisecond) + if success { + t.Error("WaitWithTimeout should have timed out") + } + + // Wait for actual completion to avoid goroutine leak in test + pool.Wait() + }) + + t.Run("PendingTasksCounter", func(t *testing.T) { + // Use pool with larger buffer (maxWorkers=2, buffer=4) + pool := NewGoroutinePool(2, nil) + defer pool.Shutdown(context.Background()) + + // Initially no pending tasks + if pool.PendingTasks() != 0 { + t.Errorf("Expected 0 pending tasks, got %d", pool.PendingTasks()) + } + + // Block both workers with signals that tasks have started + blocker1 := make(chan struct{}) + blocker2 := make(chan struct{}) + started1 := make(chan struct{}) + started2 := make(chan struct{}) + + pool.Submit(func() { + close(started1) + <-blocker1 + }) + pool.Submit(func() { + close(started2) + <-blocker2 + }) + + // Wait for both blocking tasks to actually start + <-started1 + <-started2 + + // Submit 2 more tasks that will queue up (buffer can hold 4) + for i := 0; i < 2; i++ { + pool.Submit(func() { + time.Sleep(1 * time.Millisecond) + }) + } + + // Should have pending tasks (2 running + 2 queued = 4) + pending := pool.PendingTasks() + if pending != 4 { + t.Errorf("Expected 4 pending tasks, got %d", pending) + } + + // Release blockers + close(blocker1) + close(blocker2) + + // Wait for completion + pool.Wait() + + // Should have no pending tasks + if pool.PendingTasks() != 0 { + t.Errorf("Expected 0 pending tasks after Wait, got %d", pool.PendingTasks()) + } + }) + + t.Run("MultipleWaiters", func(t *testing.T) { + pool := NewGoroutinePool(2, nil) + defer pool.Shutdown(context.Background()) + + // Submit a slow task + pool.Submit(func() { + time.Sleep(100 * time.Millisecond) + }) + + // Multiple goroutines waiting + var waiters sync.WaitGroup + var waitCount int32 + for i := 0; i < 5; i++ { + waiters.Add(1) + go func() { + defer waiters.Done() + pool.Wait() + atomic.AddInt32(&waitCount, 1) + }() + } + + // All waiters should complete + waiters.Wait() + + if atomic.LoadInt32(&waitCount) != 5 { + t.Errorf("Expected all 5 waiters to complete, got %d", waitCount) + } + }) + + t.Run("SubmitFailureDoesNotIncrementPending", func(t *testing.T) { + pool := NewGoroutinePool(1, nil) + + // Shutdown the pool + pool.Shutdown(context.Background()) + + // Submit should fail + err := pool.Submit(func() {}) + if err == nil { + t.Error("Submit should fail on shutdown pool") + } + + // Pending tasks should still be 0 + if pool.PendingTasks() != 0 { + t.Errorf("Pending tasks should be 0 after failed submit, got %d", pool.PendingTasks()) + } + }) + + t.Run("PanicRecoveryDecrementsPending", func(t *testing.T) { + pool := NewGoroutinePool(2, nil) + defer pool.Shutdown(context.Background()) + + // Submit a task that panics + pool.Submit(func() { + panic("test panic") + }) + + // Submit a normal task + var normalCompleted int32 + pool.Submit(func() { + atomic.StoreInt32(&normalCompleted, 1) + }) + + // Wait should still work (panic is recovered) + pool.Wait() + + // Normal task should have completed + if atomic.LoadInt32(&normalCompleted) != 1 { + t.Error("Normal task should have completed despite panic in other task") + } + + // Pending should be 0 + if pool.PendingTasks() != 0 { + t.Errorf("Pending tasks should be 0 after Wait, got %d", pool.PendingTasks()) + } + }) +} + +// BenchmarkGoroutinePoolWait benchmarks the Wait implementation +func BenchmarkGoroutinePoolWait(b *testing.B) { + pool := NewGoroutinePool(4, nil) + defer pool.Shutdown(context.Background()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Submit a quick task + pool.Submit(func() {}) + pool.Wait() + } +} + +// BenchmarkGoroutinePoolHighThroughput benchmarks high throughput scenario +func BenchmarkGoroutinePoolHighThroughput(b *testing.B) { + pool := NewGoroutinePool(8, nil) + defer pool.Shutdown(context.Background()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 100; j++ { + pool.Submit(func() { + // Minimal work + _ = 1 + 1 + }) + } + pool.Wait() + } +} + // Helper function to reset singleton for testing func resetResourceManagerForTesting() { resourceManagerMutex.Lock() diff --git a/token_manager.go b/token_manager.go index f99addb..dd90646 100644 --- a/token_manager.go +++ b/token_manager.go @@ -122,15 +122,22 @@ func (t *TraefikOidc) VerifyToken(token string) error { t.safeLogErrorf("Token blacklist not available, skipping JTI %s blacklist", jti) } - replayCacheMu.Lock() - if replayCache == nil { - initReplayCache() - } + // Use sharded cache for replay detection - no global mutex needed + // This reduces lock contention by ~64x under high load + initReplayCache() duration := time.Until(expiry) if duration > 0 { - replayCache.Set(jti, true, duration) + if shardedReplayCache != nil { + shardedReplayCache.Set(jti, true, duration) + } else { + // Fall back to legacy cache (should rarely happen) + replayCacheMu.Lock() + if replayCache != nil { + replayCache.Set(jti, true, duration) + } + replayCacheMu.Unlock() + } } - replayCacheMu.Unlock() } return nil diff --git a/types.go b/types.go index cb8168c..1e77cc6 100644 --- a/types.go +++ b/types.go @@ -57,6 +57,7 @@ type ProviderMetadata struct { EndSessionURL string `json:"end_session_endpoint"` IntrospectionURL string `json:"introspection_endpoint,omitempty"` // OAuth 2.0 Token Introspection (RFC 7662) ScopesSupported []string `json:"scopes_supported,omitempty"` // Supported scopes from discovery + RegistrationURL string `json:"registration_endpoint,omitempty"` // OIDC Dynamic Client Registration (RFC 7591) } // TraefikOidc is the main middleware struct that implements OIDC authentication for Traefik. @@ -128,4 +129,9 @@ type TraefikOidc struct { securityHeadersApplier func(http.ResponseWriter, *http.Request) scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering scopesSupported []string // NEW - from provider metadata + + // Dynamic Client Registration (RFC 7591) + dynamicClientRegistrar *DynamicClientRegistrar + dcrConfig *DynamicClientRegistrationConfig + registrationURL string // OIDC Dynamic Client Registration endpoint } diff --git a/universal_cache.go b/universal_cache.go index d2b5ea1..8da349b 100644 --- a/universal_cache.go +++ b/universal_cache.go @@ -34,6 +34,11 @@ type UniversalCacheConfig struct { Logger *Logger Strategy CacheStrategy // For backward compatibility + // SkipAutoCleanup skips starting the per-cache cleanup goroutine. + // Use this when cleanup is managed externally (e.g., by UniversalCacheManager) + // to reduce goroutine count and consolidate cleanup operations. + SkipAutoCleanup bool + // Type-specific configurations TokenConfig *TokenCacheConfig MetadataConfig *MetadataCacheConfig @@ -143,8 +148,12 @@ func createUniversalCache(config UniversalCacheConfig) *UniversalCache { cancel: cancel, } - // Start cleanup routine - cache.startCleanup() + // Start cleanup routine only if not skipped + // When cleanup is managed externally (e.g., by UniversalCacheManager), + // skip per-cache cleanup to reduce goroutine count + if !config.SkipAutoCleanup { + cache.startCleanup() + } return cache } diff --git a/universal_cache_singleton.go b/universal_cache_singleton.go index 9d617d7..16453b9 100644 --- a/universal_cache_singleton.go +++ b/universal_cache_singleton.go @@ -1,11 +1,14 @@ package traefikoidc import ( + "context" "sync" "time" ) // UniversalCacheManager manages all cache instances using the universal cache +// It runs a single consolidated cleanup goroutine for all caches, reducing +// goroutine count and CPU overhead compared to per-cache cleanup routines. type UniversalCacheManager struct { tokenCache *UniversalCache blacklistCache *UniversalCache @@ -16,6 +19,12 @@ type UniversalCacheManager struct { tokenTypeCache *UniversalCache // Cache for token type detection results mu sync.RWMutex logger *Logger + + // Consolidated cleanup management + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + cleanupStarted bool } var ( @@ -30,25 +39,34 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager { logger = GetSingletonNoOpLogger() } + ctx, cancel := context.WithCancel(context.Background()) + universalCacheManager = &UniversalCacheManager{ logger: logger, + ctx: ctx, + cancel: cancel, } + // Initialize all caches with SkipAutoCleanup=true to prevent 7 separate cleanup goroutines + // Instead, we use a single consolidated cleanup routine managed by this manager + // Initialize token cache - CRITICAL FIX: Reduced from 5000 to 1000 universalCacheManager.tokenCache = NewUniversalCache(UniversalCacheConfig{ - Type: CacheTypeToken, - MaxSize: 1000, // CRITICAL FIX: Reduced from 5000 to 1000 items - MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit - DefaultTTL: 1 * time.Hour, - Logger: logger, + Type: CacheTypeToken, + MaxSize: 1000, // CRITICAL FIX: Reduced from 5000 to 1000 items + MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit + DefaultTTL: 1 * time.Hour, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) // Initialize blacklist cache universalCacheManager.blacklistCache = NewUniversalCache(UniversalCacheConfig{ - Type: CacheTypeToken, - MaxSize: 1000, - DefaultTTL: 24 * time.Hour, - Logger: logger, + Type: CacheTypeToken, + MaxSize: 1000, + DefaultTTL: 24 * time.Hour, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) // Initialize metadata cache with grace periods @@ -68,46 +86,115 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager { "issuer", }, }, - Logger: logger, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) // Initialize JWK cache universalCacheManager.jwkCache = NewUniversalCache(UniversalCacheConfig{ - Type: CacheTypeJWK, - MaxSize: 200, - DefaultTTL: 1 * time.Hour, - Logger: logger, + Type: CacheTypeJWK, + MaxSize: 200, + DefaultTTL: 1 * time.Hour, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) // Initialize session cache - CRITICAL FIX: Reduced from 10000 to 2000 universalCacheManager.sessionCache = NewUniversalCache(UniversalCacheConfig{ - Type: CacheTypeSession, - MaxSize: 2000, // CRITICAL FIX: Reduced from 10000 to 2000 items - MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit - DefaultTTL: 30 * time.Minute, - Logger: logger, + Type: CacheTypeSession, + MaxSize: 2000, // CRITICAL FIX: Reduced from 10000 to 2000 items + MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit + DefaultTTL: 30 * time.Minute, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) // Initialize introspection cache for OAuth 2.0 Token Introspection (RFC 7662) universalCacheManager.introspectionCache = NewUniversalCache(UniversalCacheConfig{ - Type: CacheTypeToken, // Use token cache type for introspection results - MaxSize: 1000, // Cache up to 1000 introspection results - DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently) - Logger: logger, + Type: CacheTypeToken, // Use token cache type for introspection results + MaxSize: 1000, // Cache up to 1000 introspection results + DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently) + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) // Initialize token type cache for performance optimization universalCacheManager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{ - Type: CacheTypeToken, // Use token cache type for token type detection - MaxSize: 2000, // Cache up to 2000 token type detections - DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection - Logger: logger, + Type: CacheTypeToken, // Use token cache type for token type detection + MaxSize: 2000, // Cache up to 2000 token type detections + DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup }) + + // Start single consolidated cleanup goroutine for all caches + // This replaces 7 individual cleanup goroutines with 1 + universalCacheManager.startConsolidatedCleanup() }) return universalCacheManager } +// startConsolidatedCleanup starts a single cleanup goroutine for all caches +// This reduces goroutine count from 7 to 1 and consolidates cleanup operations +func (m *UniversalCacheManager) startConsolidatedCleanup() { + m.mu.Lock() + if m.cleanupStarted { + m.mu.Unlock() + return + } + m.cleanupStarted = true + m.mu.Unlock() + + m.wg.Add(1) + go func() { + defer m.wg.Done() + + // Use 5-minute interval for consolidated cleanup + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.performConsolidatedCleanup() + } + } + }() + + m.logger.Info("UniversalCacheManager: Started consolidated cleanup routine for all caches") +} + +// performConsolidatedCleanup runs cleanup on all caches in sequence +// This is more efficient than parallel cleanup as it reduces lock contention +func (m *UniversalCacheManager) performConsolidatedCleanup() { + m.mu.RLock() + caches := []*UniversalCache{ + m.tokenCache, + m.blacklistCache, + m.metadataCache, + m.jwkCache, + m.sessionCache, + m.introspectionCache, + m.tokenTypeCache, + } + m.mu.RUnlock() + + totalCleaned := 0 + for _, cache := range caches { + if cache != nil { + // Each cache.Cleanup() is self-contained and handles its own locking + cache.Cleanup() + } + } + + if totalCleaned > 0 { + m.logger.Debugf("UniversalCacheManager: Consolidated cleanup completed for all caches") + } +} + // GetTokenCache returns the token cache func (m *UniversalCacheManager) GetTokenCache() *UniversalCache { m.mu.RLock() @@ -157,8 +244,16 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache { return m.tokenTypeCache } -// Close shuts down all caches +// Close shuts down all caches and the consolidated cleanup routine func (m *UniversalCacheManager) Close() error { + // Stop the consolidated cleanup routine first + if m.cancel != nil { + m.cancel() + } + + // Wait for cleanup routine to finish + m.wg.Wait() + m.mu.Lock() defer m.mu.Unlock() @@ -170,7 +265,8 @@ func (m *UniversalCacheManager) Close() error { } } - m.logger.Info("UniversalCacheManager: Closed all caches") + m.cleanupStarted = false + m.logger.Info("UniversalCacheManager: Closed all caches and cleanup routine") return nil }