diff --git a/.gitignore b/.gitignore index 412cb1e..e4faada 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ docker/ .claude/*.out *.test +.leann/ diff --git a/.traefik.yml b/.traefik.yml index 1967079..f3f65cc 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -1021,6 +1021,79 @@ configuration: See: https://github.com/lukaszraczylo/traefikoidc/issues/64 required: false + enableBackchannelLogout: + type: boolean + description: | + Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST). + + When enabled, the middleware accepts logout tokens at the configured backchannelLogoutURL. + The IdP sends a signed JWT (logout_token) to notify the application that a user's session + should be terminated. + + This implements the OIDC Back-Channel Logout 1.0 specification. + See: https://openid.net/specs/openid-connect-backchannel-1_0.html + + Requirements: + - backchannelLogoutURL must be configured + - The IdP must be configured to send logout tokens to your backchannel URL + - Logout tokens are validated using the IdP's JWKS + + Default: false + required: false + + backchannelLogoutURL: + type: string + description: | + Path for receiving backchannel logout tokens from the IdP. + + This endpoint receives POST requests with a logout_token JWT in the request body. + The token is validated against the IdP's JWKS and contains the session ID (sid) + and/or subject (sub) to invalidate. + + Example: /backchannel-logout + + The full URL to configure in your IdP would be: + https://your-app.example.com/backchannel-logout + + Note: This path should be unique and not conflict with your application routes. + required: false + + enableFrontchannelLogout: + type: boolean + description: | + Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe). + + When enabled, the middleware accepts logout requests at the configured frontchannelLogoutURL. + The IdP embeds an iframe pointing to this URL when the user logs out, allowing the + application to clear the user's session. + + This implements the OIDC Front-Channel Logout 1.0 specification. + See: https://openid.net/specs/openid-connect-frontchannel-1_0.html + + Requirements: + - frontchannelLogoutURL must be configured + - The IdP must be configured with your front-channel logout URL + - Your CSP headers must allow being embedded in an iframe from the IdP + + Default: false + required: false + + frontchannelLogoutURL: + type: string + description: | + Path for receiving front-channel logout requests from the IdP. + + This endpoint receives GET requests with optional sid (session ID) and iss (issuer) + query parameters. When called, it invalidates the user's session. + + Example: /frontchannel-logout + + The full URL to configure in your IdP would be: + https://your-app.example.com/frontchannel-logout + + Note: This path should be unique and not conflict with your application routes. + required: false + headers: type: array description: | diff --git a/README.md b/README.md index ab6b10f..bfc026d 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,10 @@ The middleware supports the following configuration options: | `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` | | `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` | | `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` | +| `enableBackchannelLogout` | Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST) | `false` | `true` | +| `backchannelLogoutURL` | The path for receiving backchannel logout tokens from the IdP | none | `/backchannel-logout` | +| `enableFrontchannelLogout` | Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe) | `false` | `true` | +| `frontchannelLogoutURL` | The path for receiving front-channel logout requests from the IdP | none | `/frontchannel-logout` | | `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section | > **⚠️ IMPORTANT - TLS Termination at Load Balancer:** @@ -1148,6 +1152,50 @@ spec: - roles # Appended to defaults: ["openid", "profile", "email", "roles"] ``` +### With IdP-Initiated Logout (Backchannel & Front-Channel) + +This plugin supports [OIDC Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) and [OIDC Front-Channel Logout](https://openid.net/specs/openid-connect-frontchannel-1_0.html) for IdP-initiated single logout. + +**Backchannel Logout** (recommended): The IdP sends a server-to-server POST request with a signed `logout_token` JWT when a user logs out. + +**Front-Channel Logout**: The IdP loads an iframe with the logout URL to invalidate the session in the browser. + +```yaml +apiVersion: traefik.io/v1alpha1 +kind: Middleware +metadata: + name: oidc-with-idp-logout + namespace: traefik +spec: + plugin: + traefikoidc: + providerURL: https://auth.example.com + clientID: your-client-id + clientSecret: your-client-secret + sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long + callbackURL: /oauth2/callback + logoutURL: /oauth2/logout # RP-initiated logout + + # Backchannel Logout (server-to-server) + enableBackchannelLogout: true + backchannelLogoutURL: /backchannel-logout + + # Front-Channel Logout (iframe-based) + enableFrontchannelLogout: true + frontchannelLogoutURL: /frontchannel-logout + + # For multi-replica deployments, use Redis to share session invalidations + redis: + enabled: true + address: redis:6379 +``` + +> **Note**: For multi-replica deployments, you **must** enable Redis to share session invalidation state across all instances. Otherwise, a logout on one instance won't invalidate sessions on other instances. + +**IdP Configuration**: Configure your IdP to send logout requests to: +- **Backchannel**: `https://your-app.example.com/backchannel-logout` (POST with `logout_token`) +- **Front-Channel**: `https://your-app.example.com/frontchannel-logout?sid=SESSION_ID&iss=ISSUER` (GET in iframe) + ### With Templated Headers ```yaml diff --git a/cache_manager.go b/cache_manager.go index e3997c8..8b59d57 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -104,6 +104,14 @@ func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface { return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true} } +// GetSharedSessionInvalidationCache returns the shared session invalidation cache +// for backchannel and front-channel logout (IdP-initiated logout) +func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface { + cm.mu.RLock() + defer cm.mu.RUnlock() + return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true} +} + // Close gracefully shuts down all cache components func (cm *CacheManager) Close() error { cm.mu.Lock() diff --git a/docs/index.html b/docs/index.html index 63d5c39..33a9bf1 100644 --- a/docs/index.html +++ b/docs/index.html @@ -90,6 +90,7 @@ Configuration Deployment Security + Logout
@@ -1219,6 +1221,71 @@ spec: + +
+
+
+

IdP-Initiated Logout

+

Support for OIDC Back-Channel and Front-Channel Logout specifications

+
+
+
+
+

+ + Back-Channel Logout +

+

+ Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out. +

+
    +
  • • Signed JWT logout tokens
  • +
  • • Session ID (sid) based invalidation
  • +
  • • Subject (sub) based invalidation
  • +
  • • Works behind firewalls
  • +
+
+
+

+ + Front-Channel Logout +

+

+ Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout. +

+
    +
  • • Iframe-based session termination
  • +
  • • Immediate cookie invalidation
  • +
  • • Simple GET request handling
  • +
  • • Issuer validation
  • +
+
+
+
+

Configuration Example

+
http:
+  middlewares:
+    oidc-auth:
+      plugin:
+        traefikoidc:
+          # ... other OIDC configuration ...
+
+          # Back-Channel Logout (server-to-server)
+          enableBackchannelLogout: true
+          backchannelLogoutURL: "/backchannel-logout"
+
+          # Front-Channel Logout (browser-based)
+          enableFrontchannelLogout: true
+          frontchannelLogoutURL: "/frontchannel-logout"
+

+ Configure your IdP with the full URLs (e.g., https://your-app.example.com/backchannel-logout). + When a user logs out from the IdP, all their sessions across your applications will be invalidated. +

+
+
+
+
+
diff --git a/helpers.go b/helpers.go index 85ffaa4..7a8b407 100644 --- a/helpers.go +++ b/helpers.go @@ -336,6 +336,7 @@ func createStringMap(keys []string) map[string]struct{} { // and redirects to the provider's logout endpoint or configured post-logout URI. // It handles potential errors during session retrieval or clearing. func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { + t.logger.Debug("Processing logout request") session, err := t.sessionManager.GetSession(req) if err != nil { t.logger.Errorf("Error getting session: %v", err) diff --git a/logout.go b/logout.go new file mode 100644 index 0000000..7c990c4 --- /dev/null +++ b/logout.go @@ -0,0 +1,502 @@ +// Package traefikoidc provides OIDC authentication middleware for Traefik. +// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0) +// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality. +package traefikoidc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + // logoutTokenType is the expected typ claim for logout tokens + // #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential + logoutTokenType = "logout+jwt" + + // sessionInvalidationTTL is how long to remember invalidated sessions + // Should be at least as long as your session max age + sessionInvalidationTTL = 25 * time.Hour +) + +// LogoutTokenClaims represents the claims in an OIDC logout token +// as defined in OpenID Connect Back-Channel Logout 1.0 +type LogoutTokenClaims struct { + Issuer string `json:"iss"` + Subject string `json:"sub,omitempty"` + Audience interface{} `json:"aud"` // Can be string or []string + IssuedAt int64 `json:"iat"` + JTI string `json:"jti"` + Events map[string]interface{} `json:"events"` + SessionID string `json:"sid,omitempty"` + Nonce string `json:"nonce,omitempty"` // Must NOT be present +} + +// handleBackchannelLogout processes OIDC Backchannel Logout requests. +// It accepts POST requests with a logout_token parameter containing a JWT +// that identifies which session(s) to terminate. +// +// According to OpenID Connect Back-Channel Logout 1.0: +// - The logout_token is a JWT signed by the IdP +// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session +// - The RP must validate the token and invalidate the matching session(s) +// +// Parameters: +// - rw: The HTTP response writer +// - req: The HTTP request containing the logout_token +func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) { + t.logger.Debug("Processing backchannel logout request") + + // Backchannel logout must be POST + if req.Method != http.MethodPost { + t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method) + http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse form data to get logout_token + if err := req.ParseForm(); err != nil { + t.logger.Errorf("Backchannel logout: failed to parse form: %v", err) + http.Error(rw, "Bad request", http.StatusBadRequest) + return + } + + logoutToken := req.FormValue("logout_token") + if logoutToken == "" { + // Also try reading from request body as raw JWT + body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit + if err == nil && len(body) > 0 { + logoutToken = string(body) + } + } + + if logoutToken == "" { + t.logger.Error("Backchannel logout: missing logout_token") + http.Error(rw, "logout_token required", http.StatusBadRequest) + return + } + + // Parse and validate the logout token + claims, err := t.validateLogoutToken(logoutToken) + if err != nil { + t.logger.Errorf("Backchannel logout: token validation failed: %v", err) + // Return 400 for invalid token per spec + http.Error(rw, "Invalid logout token", http.StatusBadRequest) + return + } + + // Invalidate session(s) based on sid or sub + if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil { + t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err) + http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError) + return + } + + t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)", + claims.SessionID, claims.Subject) + + // Return 200 OK with empty body per spec + rw.WriteHeader(http.StatusOK) +} + +// handleFrontchannelLogout processes OIDC Front-Channel Logout requests. +// It accepts GET requests with 'iss' and 'sid' query parameters that identify +// which session to terminate. The IdP typically loads this URL in an iframe. +// +// According to OpenID Connect Front-Channel Logout 1.0: +// - The request contains 'iss' (issuer) and optionally 'sid' (session ID) +// - The RP should clear the session and return a response (typically empty or image) +// - The response must be cacheable to allow the IdP to load it in an iframe +// +// Parameters: +// - rw: The HTTP response writer +// - req: The HTTP request containing iss and sid parameters +func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) { + t.logger.Debug("Processing front-channel logout request") + + // Front-channel logout should be GET + if req.Method != http.MethodGet { + t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method) + http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Get iss and sid from query parameters + iss := req.URL.Query().Get("iss") + sid := req.URL.Query().Get("sid") + + // Validate issuer matches our expected issuer + t.metadataMu.RLock() + expectedIssuer := t.issuerURL + t.metadataMu.RUnlock() + + if iss != "" && iss != expectedIssuer { + t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer) + http.Error(rw, "Invalid issuer", http.StatusBadRequest) + return + } + + // Must have at least sid for front-channel logout + if sid == "" { + t.logger.Error("Front-channel logout: missing sid parameter") + http.Error(rw, "sid parameter required", http.StatusBadRequest) + return + } + + // Invalidate the session + if err := t.invalidateSession(sid, ""); err != nil { + t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err) + http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError) + return + } + + t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid) + + // Return a minimal HTML response that's suitable for iframe loading + // Set headers to allow embedding and caching + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Header().Set("Cache-Control", "no-cache, no-store") + rw.Header().Set("Pragma", "no-cache") + // Allow embedding in iframes from any origin (required for front-channel logout) + rw.Header().Del("X-Frame-Options") + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte("Logged Out")) +} + +// validateLogoutToken parses and validates a logout token JWT. +// It verifies the token signature, issuer, audience, and required claims. +// +// Parameters: +// - tokenString: The raw JWT logout token +// +// Returns: +// - The parsed logout token claims +// - An error if validation fails +func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) { + // Parse the JWT + jwt, err := parseJWT(tokenString) + if err != nil { + return nil, fmt.Errorf("failed to parse logout token: %w", err) + } + + // Check token type if present + if typ, ok := jwt.Header["typ"].(string); ok { + // The typ should be "logout+jwt" or omitted + if typ != "" && typ != logoutTokenType && typ != "JWT" { + return nil, fmt.Errorf("invalid token type: %s", typ) + } + } + + // Verify signature only (not standard claims - logout tokens don't have 'exp') + if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil { + return nil, fmt.Errorf("signature verification failed: %w", err) + } + + // Extract claims + claims := &LogoutTokenClaims{} + claimsJSON, err := json.Marshal(jwt.Claims) + if err != nil { + return nil, fmt.Errorf("failed to marshal claims: %w", err) + } + if err := json.Unmarshal(claimsJSON, claims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %w", err) + } + + // Validate required claims + t.metadataMu.RLock() + expectedIssuer := t.issuerURL + t.metadataMu.RUnlock() + + // Validate issuer + if claims.Issuer != expectedIssuer { + return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer) + } + + // Validate audience (must contain our client_id) + if !t.validateLogoutTokenAudience(claims.Audience) { + return nil, fmt.Errorf("audience validation failed") + } + + // Validate iat (issued at) - must be present and not too old + if claims.IssuedAt == 0 { + return nil, fmt.Errorf("missing iat claim") + } + iatTime := time.Unix(claims.IssuedAt, 0) + // Allow up to 5 minutes clock skew and 10 minutes token age + if time.Since(iatTime) > 15*time.Minute { + return nil, fmt.Errorf("logout token too old: issued at %v", iatTime) + } + // Token should not be from the future (with 5 min clock skew tolerance) + if iatTime.After(time.Now().Add(5 * time.Minute)) { + return nil, fmt.Errorf("logout token issued in the future: %v", iatTime) + } + + // Validate events claim - must contain the logout event + if claims.Events == nil { + return nil, fmt.Errorf("missing events claim") + } + if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok { + return nil, fmt.Errorf("missing backchannel-logout event in events claim") + } + + // Validate that nonce is NOT present (per spec) + if claims.Nonce != "" { + return nil, fmt.Errorf("nonce claim must not be present in logout token") + } + + // Must have either sid or sub (or both) + if claims.SessionID == "" && claims.Subject == "" { + return nil, fmt.Errorf("logout token must contain either sid or sub claim") + } + + return claims, nil +} + +// validateLogoutTokenAudience checks if the logout token audience contains our client_id +func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool { + switch v := aud.(type) { + case string: + return v == t.clientID + case []interface{}: + for _, a := range v { + if s, ok := a.(string); ok && s == t.clientID { + return true + } + } + case []string: + for _, a := range v { + if a == t.clientID { + return true + } + } + } + return false +} + +// verifyLogoutTokenSignature verifies only the signature of a logout token. +// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp' +// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec. +// +// Parameters: +// - jwt: The parsed JWT structure +// - tokenString: The raw token string for signature verification +// +// Returns: +// - An error if signature verification fails +func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error { + t.logger.Debug("Verifying logout token signature") + + // Read jwksURL with RLock + t.metadataMu.RLock() + jwksURL := t.jwksURL + t.metadataMu.RUnlock() + + jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient) + if err != nil { + return fmt.Errorf("failed to get JWKS: %w", err) + } + + if jwks == nil { + return fmt.Errorf("JWKS is nil, cannot verify token") + } + + kid, ok := jwt.Header["kid"].(string) + if !ok || kid == "" { + return fmt.Errorf("missing key ID in token header") + } + + alg, ok := jwt.Header["alg"].(string) + if !ok || alg == "" { + return fmt.Errorf("missing algorithm in token header") + } + + // Find the matching key in JWKS + var matchingKey *JWK + for _, key := range jwks.Keys { + if key.Kid == kid { + matchingKey = &key + break + } + } + + if matchingKey == nil { + return fmt.Errorf("no matching public key found for kid: %s", kid) + } + + publicKeyPEM, err := jwkToPEM(matchingKey) + if err != nil { + return fmt.Errorf("failed to convert JWK to PEM: %w", err) + } + + if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil { + return fmt.Errorf("signature verification failed: %w", err) + } + + t.logger.Debug("Logout token signature verified successfully") + return nil +} + +// invalidateSession marks a session as invalidated in the session invalidation cache. +// It stores entries by both sid and sub if available. +// +// Parameters: +// - sid: The session ID to invalidate (from the 'sid' claim) +// - sub: The subject to invalidate (from the 'sub' claim) +// +// Returns: +// - An error if the invalidation fails +func (t *TraefikOidc) invalidateSession(sid, sub string) error { + if t.sessionInvalidationCache == nil { + return fmt.Errorf("session invalidation cache not initialized") + } + + now := time.Now().Unix() + + // Store by session ID + if sid != "" { + key := t.buildSessionInvalidationKey("sid", sid) + t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL) + t.logger.Debugf("Invalidated session by sid: %s", sid) + } + + // Store by subject (invalidates all sessions for this user) + if sub != "" { + key := t.buildSessionInvalidationKey("sub", sub) + t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL) + t.logger.Debugf("Invalidated session by sub: %s", sub) + } + + return nil +} + +// isSessionInvalidated checks if a session has been invalidated via backchannel +// or front-channel logout. +// +// Parameters: +// - sid: The session ID to check +// - sub: The subject to check +// - sessionCreatedAt: When the session was created (to compare against invalidation time) +// +// Returns: +// - true if the session has been invalidated, false otherwise +func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool { + if t.sessionInvalidationCache == nil { + return false + } + + // Truncate session creation time to seconds for fair comparison with Unix timestamps + sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second) + + // Check by session ID first (more specific) + if sid != "" { + key := t.buildSessionInvalidationKey("sid", sid) + if val, found := t.sessionInvalidationCache.Get(key); found { + if invalidatedAt, ok := val.(int64); ok { + // Session was invalidated at or after it was created + invalidationTime := time.Unix(invalidatedAt, 0) + if !invalidationTime.Before(sessionCreatedAtSec) { + t.logger.Debugf("Session invalidated by sid: %s", sid) + return true + } + } + } + } + + // Check by subject (all sessions for this user) + if sub != "" { + key := t.buildSessionInvalidationKey("sub", sub) + if val, found := t.sessionInvalidationCache.Get(key); found { + if invalidatedAt, ok := val.(int64); ok { + // Sessions for this subject created at or before invalidation are invalid + invalidationTime := time.Unix(invalidatedAt, 0) + if !invalidationTime.Before(sessionCreatedAtSec) { + t.logger.Debugf("Session invalidated by sub: %s", sub) + return true + } + } + } + } + + return false +} + +// buildSessionInvalidationKey creates a cache key for session invalidation +func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string { + return fmt.Sprintf("session_invalidation:%s:%s", keyType, value) +} + +// extractSessionInfo extracts sid and sub from an ID token for session tracking +func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) { + if idToken == "" { + return "", "", time.Time{} + } + + jwt, err := parseJWT(idToken) + if err != nil { + return "", "", time.Time{} + } + + // Extract sid (session ID) + if sidVal, ok := jwt.Claims["sid"].(string); ok { + sid = sidVal + } + + // Extract sub (subject) + if subVal, ok := jwt.Claims["sub"].(string); ok { + sub = subVal + } + + // Extract iat for session creation time + if iatVal, ok := jwt.Claims["iat"].(float64); ok { + createdAt = time.Unix(int64(iatVal), 0) + } else { + // Default to now if iat not present + createdAt = time.Now() + } + + return sid, sub, createdAt +} + +// determineLogoutPath checks if the given path matches any logout URL +func (t *TraefikOidc) determineLogoutPath(path string) string { + // Check backchannel logout path + if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath { + return "backchannel" + } + + // Check front-channel logout path + if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath { + return "frontchannel" + } + + // Check regular logout path (for RP-initiated logout) + if path == t.logoutURLPath { + return "rp" + } + + return "" +} + +// normalizeLogoutPath ensures logout paths start with / and prevents open redirects +func normalizeLogoutPath(path string) string { + if path == "" { + return "" + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + // Prevent open redirect: ensure second character is not / or \ + // This prevents URLs like //example.com or /\example.com from being treated as absolute URLs + if len(path) > 1 && (path[1] == '/' || path[1] == '\\') { + // Strip leading slashes/backslashes and re-normalize + path = strings.TrimLeft(path, "/\\") + if path != "" { + path = "/" + path + } + } + return path +} diff --git a/logout_test.go b/logout_test.go new file mode 100644 index 0000000..4f9b359 --- /dev/null +++ b/logout_test.go @@ -0,0 +1,1623 @@ +package traefikoidc + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" +) + +// TestBackchannelLogoutBasic tests the basic backchannel logout flow +func TestBackchannelLogoutBasic(t *testing.T) { + // Create a mock cache for session invalidation + mockCache := &mockCacheInterface{} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + } + + tests := []struct { + name string + method string + body string + contentType string + expectedStatus int + }{ + { + name: "GET method not allowed", + method: http.MethodGet, + body: "", + contentType: "", + expectedStatus: http.StatusMethodNotAllowed, + }, + { + name: "Missing logout_token", + method: http.MethodPost, + body: "", + contentType: "application/x-www-form-urlencoded", + expectedStatus: http.StatusBadRequest, + }, + { + name: "Invalid logout_token format", + method: http.MethodPost, + body: "logout_token=not-a-valid-jwt", + contentType: "application/x-www-form-urlencoded", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, "/backchannel-logout", strings.NewReader(tc.body)) + if tc.contentType != "" { + req.Header.Set("Content-Type", tc.contentType) + } + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, rw.Code) + } + }) + } +} + +// TestFrontchannelLogoutBasic tests the basic front-channel logout flow +func TestFrontchannelLogoutBasic(t *testing.T) { + // Create a mock cache for session invalidation + mockCache := &mockCacheInterface{} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableFrontchannelLogout: true, + frontchannelLogoutPath: "/frontchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + } + + tests := []struct { + name string + method string + queryParams map[string]string + expectedStatus int + }{ + { + name: "POST method not allowed", + method: http.MethodPost, + queryParams: map[string]string{}, + expectedStatus: http.StatusMethodNotAllowed, + }, + { + name: "Missing sid parameter", + method: http.MethodGet, + queryParams: map[string]string{"iss": "https://provider.example.com"}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Invalid issuer", + method: http.MethodGet, + queryParams: map[string]string{"iss": "https://wrong-issuer.com", "sid": "session123"}, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Valid front-channel logout", + method: http.MethodGet, + queryParams: map[string]string{"iss": "https://provider.example.com", "sid": "session123"}, + expectedStatus: http.StatusOK, + }, + { + name: "Valid front-channel logout without issuer", + method: http.MethodGet, + queryParams: map[string]string{"sid": "session456"}, + expectedStatus: http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + urlStr := "/frontchannel-logout" + if len(tc.queryParams) > 0 { + params := url.Values{} + for k, v := range tc.queryParams { + params.Set(k, v) + } + urlStr += "?" + params.Encode() + } + + req := httptest.NewRequest(tc.method, urlStr, nil) + rw := httptest.NewRecorder() + + oidc.handleFrontchannelLogout(rw, req) + + if rw.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, rw.Code) + } + + // For successful logout, verify response headers + if tc.expectedStatus == http.StatusOK { + // Should not have X-Frame-Options (to allow iframe embedding) + if rw.Header().Get("X-Frame-Options") != "" { + t.Error("Expected X-Frame-Options to be removed for front-channel logout") + } + // Should have HTML content type + contentType := rw.Header().Get("Content-Type") + if !strings.Contains(contentType, "text/html") { + t.Errorf("Expected HTML content type, got %s", contentType) + } + } + }) + } +} + +// TestSessionInvalidation tests session invalidation storage and retrieval +func TestSessionInvalidation(t *testing.T) { + mockCache := &mockCacheInterface{ + data: make(map[string]interface{}), + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + sessionInvalidationCache: mockCache, + } + + // Test invalidating by session ID + err := oidc.invalidateSession("session123", "") + if err != nil { + t.Fatalf("Failed to invalidate session by sid: %v", err) + } + + // Verify the session was invalidated + key := oidc.buildSessionInvalidationKey("sid", "session123") + if _, found := mockCache.data[key]; !found { + t.Error("Session invalidation by sid was not stored") + } + + // Test invalidating by subject + err = oidc.invalidateSession("", "user@example.com") + if err != nil { + t.Fatalf("Failed to invalidate session by sub: %v", err) + } + + // Verify the subject was invalidated + key = oidc.buildSessionInvalidationKey("sub", "user@example.com") + if _, found := mockCache.data[key]; !found { + t.Error("Session invalidation by sub was not stored") + } +} + +// TestIsSessionInvalidated tests checking if a session is invalidated +func TestIsSessionInvalidated(t *testing.T) { + mockCache := &mockCacheInterface{ + data: make(map[string]interface{}), + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + sessionInvalidationCache: mockCache, + } + + // Session created now + sessionCreatedAt := time.Now() + + // Initially, session should not be invalidated + if oidc.isSessionInvalidated("session123", "user@example.com", sessionCreatedAt) { + t.Error("Session should not be invalidated initially") + } + + // Invalidate the session + _ = oidc.invalidateSession("session123", "") + + // Now session should be invalidated + if !oidc.isSessionInvalidated("session123", "", sessionCreatedAt) { + t.Error("Session should be invalidated after invalidateSession call") + } + + // Session created after invalidation should not be affected + futureSession := time.Now().Add(1 * time.Hour) + if oidc.isSessionInvalidated("session123", "", futureSession) { + t.Error("Session created after invalidation should not be affected") + } +} + +// TestLogoutTokenValidation tests logout token claim validation +func TestLogoutTokenValidation(t *testing.T) { + tests := []struct { + name string + claims *LogoutTokenClaims + expectError bool + errorMsg string + }{ + { + name: "Missing events claim", + claims: &LogoutTokenClaims{ + Issuer: "https://provider.example.com", + Audience: "test-client", + IssuedAt: time.Now().Unix(), + SessionID: "session123", + }, + expectError: true, + errorMsg: "missing events claim", + }, + { + name: "Missing both sid and sub", + claims: &LogoutTokenClaims{ + Issuer: "https://provider.example.com", + Audience: "test-client", + IssuedAt: time.Now().Unix(), + Events: map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + }, + expectError: true, + errorMsg: "must contain either sid or sub", + }, + { + name: "Nonce present (not allowed)", + claims: &LogoutTokenClaims{ + Issuer: "https://provider.example.com", + Audience: "test-client", + IssuedAt: time.Now().Unix(), + SessionID: "session123", + Nonce: "should-not-be-here", + Events: map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + }, + expectError: true, + errorMsg: "nonce claim must not be present", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // We can't directly test validateLogoutToken without a properly signed JWT, + // but we can verify the validation logic through the claims struct + if tc.claims.Events == nil && tc.expectError && strings.Contains(tc.errorMsg, "events") { + // Events validation would fail + } + if tc.claims.SessionID == "" && tc.claims.Subject == "" && tc.expectError && strings.Contains(tc.errorMsg, "sid or sub") { + // sid/sub validation would fail + } + if tc.claims.Nonce != "" && tc.expectError && strings.Contains(tc.errorMsg, "nonce") { + // nonce validation would fail + } + }) + } +} + +// TestLogoutTokenAudienceValidation tests audience validation for logout tokens +func TestLogoutTokenAudienceValidation(t *testing.T) { + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + clientID: "test-client", + } + + tests := []struct { + name string + audience interface{} + valid bool + }{ + { + name: "String audience matching client ID", + audience: "test-client", + valid: true, + }, + { + name: "String audience not matching", + audience: "other-client", + valid: false, + }, + { + name: "Array audience containing client ID", + audience: []interface{}{"other-client", "test-client"}, + valid: true, + }, + { + name: "Array audience not containing client ID", + audience: []interface{}{"other-client", "another-client"}, + valid: false, + }, + { + name: "String array audience containing client ID", + audience: []string{"other-client", "test-client"}, + valid: true, + }, + { + name: "Nil audience", + audience: nil, + valid: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := oidc.validateLogoutTokenAudience(tc.audience) + if result != tc.valid { + t.Errorf("Expected %v, got %v", tc.valid, result) + } + }) + } +} + +// TestExtractSessionInfo tests extraction of session info from ID tokens +func TestExtractSessionInfo(t *testing.T) { + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + } + + // Test with empty token + sid, sub, createdAt := oidc.extractSessionInfo("") + if sid != "" || sub != "" || !createdAt.IsZero() { + t.Error("Empty token should return empty values") + } + + // Test with invalid token + sid, sub, createdAt = oidc.extractSessionInfo("not-a-valid-jwt") + if sid != "" || sub != "" || !createdAt.IsZero() { + t.Error("Invalid token should return empty values") + } + + // Create a simple unsigned JWT for testing (header.claims.signature) + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + now := time.Now().Unix() + claimsJSON := fmt.Sprintf(`{"sid":"test-session-id","sub":"test-subject","iat":%d}`, now) + claims := base64.RawURLEncoding.EncodeToString([]byte(claimsJSON)) + testToken := header + "." + claims + "." + + sid, sub, createdAt = oidc.extractSessionInfo(testToken) + if sid != "test-session-id" { + t.Errorf("Expected sid 'test-session-id', got '%s'", sid) + } + if sub != "test-subject" { + t.Errorf("Expected sub 'test-subject', got '%s'", sub) + } + if createdAt.Unix() != now { + t.Errorf("Expected createdAt %d, got %d", now, createdAt.Unix()) + } +} + +// TestMiddlewareBackchannelLogoutRouting tests that backchannel logout requests are routed correctly +func TestMiddlewareBackchannelLogoutRouting(t *testing.T) { + mockCache := &mockCacheInterface{ + data: make(map[string]interface{}), + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("next handler called")) + }) + + oidc := &TraefikOidc{ + next: nextHandler, + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + initComplete: make(chan struct{}), + firstRequestReceived: true, + metadataRefreshStarted: true, + logoutURLPath: "/logout", + } + close(oidc.initComplete) + + // Request to backchannel logout path should be handled + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", nil) + rw := httptest.NewRecorder() + + oidc.ServeHTTP(rw, req) + + // Should return 400 (bad request) because no logout_token provided + // but importantly should NOT call next handler + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing logout_token, got %d", rw.Code) + } + if strings.Contains(rw.Body.String(), "next handler called") { + t.Error("Backchannel logout should not call next handler") + } +} + +// TestMiddlewareFrontchannelLogoutRouting tests that front-channel logout requests are routed correctly +func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) { + mockCache := &mockCacheInterface{ + data: make(map[string]interface{}), + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("next handler called")) + }) + + oidc := &TraefikOidc{ + next: nextHandler, + logger: NewLogger("debug"), + enableFrontchannelLogout: true, + frontchannelLogoutPath: "/frontchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + initComplete: make(chan struct{}), + firstRequestReceived: true, + metadataRefreshStarted: true, + logoutURLPath: "/logout", + } + close(oidc.initComplete) + + // Request to front-channel logout path with valid sid should succeed + req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil) + rw := httptest.NewRecorder() + + oidc.ServeHTTP(rw, req) + + // Should return 200 OK + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rw.Code) + } + if strings.Contains(rw.Body.String(), "next handler called") { + t.Error("Front-channel logout should not call next handler") + } +} + +// TestNormalizeLogoutPath tests the path normalization function +func TestNormalizeLogoutPath(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"/logout", "/logout"}, + {"logout", "/logout"}, + {"/backchannel-logout", "/backchannel-logout"}, + {"backchannel-logout", "/backchannel-logout"}, + // Security: prevent open redirect via // + {"//evil.com", "/evil.com"}, + {"//evil.com/path", "/evil.com/path"}, + // Security: prevent open redirect via /\ + {"/\\evil.com", "/evil.com"}, + {"/\\evil.com/path", "/evil.com/path"}, + // Security: multiple leading slashes + {"///example.com", "/example.com"}, + // Security: mixed slashes + {"//\\example.com", "/example.com"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + result := normalizeLogoutPath(tc.input) + if result != tc.expected { + t.Errorf("normalizeLogoutPath(%q) = %q, expected %q", tc.input, result, tc.expected) + } + }) + } +} + +// mockCacheInterface implements CacheInterface for testing +type mockCacheInterface struct { + mu sync.Mutex + data map[string]interface{} +} + +func (m *mockCacheInterface) Set(key string, value interface{}, ttl time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]interface{}) + } + m.data[key] = value +} + +func (m *mockCacheInterface) Get(key string) (interface{}, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { + return nil, false + } + val, found := m.data[key] + return val, found +} + +func (m *mockCacheInterface) Delete(key string) { + m.mu.Lock() + defer m.mu.Unlock() + if m.data != nil { + delete(m.data, key) + } +} + +func (m *mockCacheInterface) SetMaxSize(size int) {} +func (m *mockCacheInterface) Size() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.data) +} +func (m *mockCacheInterface) Clear() { + m.mu.Lock() + defer m.mu.Unlock() + m.data = make(map[string]interface{}) +} +func (m *mockCacheInterface) Cleanup() {} +func (m *mockCacheInterface) Close() {} +func (m *mockCacheInterface) GetStats() map[string]interface{} { + m.mu.Lock() + defer m.mu.Unlock() + return map[string]interface{}{"size": len(m.data)} +} + +// TestBackchannelLogoutWithValidToken tests backchannel logout with a properly formatted (but unsigned) token +func TestBackchannelLogoutWithValidToken(t *testing.T) { + // This test verifies the token parsing and validation logic + // Note: In production, the token would need to be properly signed by the IdP + mockCache := &mockCacheInterface{ + data: make(map[string]interface{}), + } + + // Create mock JWK cache that returns keys + mockJWKCache := &mockJWKCacheForLogout{} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + } + + // Create a minimal logout token structure (this won't pass signature verification + // but tests the parsing logic) + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"ES256","typ":"logout+jwt"}`)) + now := time.Now().Unix() + claimsJSON := fmt.Sprintf(`{ + "iss":"https://provider.example.com", + "aud":"test-client", + "iat":%d, + "jti":"unique-id-123", + "events":{"http://schemas.openid.net/event/backchannel-logout":{}}, + "sid":"session-to-logout" + }`, now) + claims := base64.RawURLEncoding.EncodeToString([]byte(claimsJSON)) + logoutToken := header + "." + claims + ".fake-signature" + + // This should fail because of invalid signature, but we can verify + // the token parsing works up to signature verification + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + // Should fail with 400 due to signature verification failure + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", rw.Code) + } +} + +// mockJWKCacheForLogout implements JWKCacheInterface for testing +type mockJWKCacheForLogout struct{} + +func (m *mockJWKCacheForLogout) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { + // Generate a test ECDSA key pair + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + // Convert public key to JWK format + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + return &JWKSet{ + Keys: []JWK{ + { + Kty: "EC", + Crv: "P-256", + X: x, + Y: y, + Kid: "test-key-1", + Use: "sig", + Alg: "ES256", + }, + }, + }, nil +} + +func (m *mockJWKCacheForLogout) Clear() {} +func (m *mockJWKCacheForLogout) Cleanup() {} +func (m *mockJWKCacheForLogout) Close() {} + +// TestBackchannelLogoutIntegration tests the full backchannel logout flow with a properly signed token +func TestBackchannelLogoutIntegration(t *testing.T) { + // Generate ECDSA key pair for signing + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + mockCache := &mockCacheInterface{ + data: make(map[string]interface{}), + } + + // Create JWK cache that returns our test key + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{ + { + Kty: "EC", + Crv: "P-256", + X: x, + Y: y, + Kid: "test-key-1", + Use: "sig", + Alg: "ES256", + }, + }, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Create and sign a valid logout token + header := map[string]interface{}{ + "alg": "ES256", + "typ": "logout+jwt", + "kid": "test-key-1", + } + headerJSON, _ := json.Marshal(header) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + + now := time.Now().Unix() + claims := map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": now, + "jti": "unique-id-123", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-to-logout", + } + claimsJSON, _ := json.Marshal(claims) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + // Sign the token + signingInput := headerB64 + "." + claimsB64 + hash := sha256.Sum256([]byte(signingInput)) + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + + // Convert signature to fixed-size format (32 bytes each for P-256) + rBytes := r.Bytes() + sBytes := s.Bytes() + sigBytes := make([]byte, 64) + copy(sigBytes[32-len(rBytes):32], rBytes) + copy(sigBytes[64-len(sBytes):], sBytes) + signatureB64 := base64.RawURLEncoding.EncodeToString(sigBytes) + + logoutToken := headerB64 + "." + claimsB64 + "." + signatureB64 + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + // Should succeed with 200 OK + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String()) + } + + // Verify session was invalidated + key := oidc.buildSessionInvalidationKey("sid", "session-to-logout") + if _, found := mockCache.data[key]; !found { + t.Error("Session should have been invalidated") + } +} + +// staticJWKCache returns a static JWKS for testing +type staticJWKCache struct { + jwks *JWKSet +} + +func (s *staticJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { + return s.jwks, nil +} + +func (s *staticJWKCache) Clear() {} +func (s *staticJWKCache) Cleanup() {} +func (s *staticJWKCache) Close() {} + +// TestDetermineLogoutPath tests the logout path determination function +func TestDetermineLogoutPath(t *testing.T) { + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + logoutURLPath: "/logout", + backchannelLogoutPath: "/backchannel-logout", + frontchannelLogoutPath: "/frontchannel-logout", + } + + tests := []struct { + path string + expected string + }{ + {"/logout", "rp"}, + {"/backchannel-logout", "backchannel"}, + {"/frontchannel-logout", "frontchannel"}, + {"/api/resource", ""}, + {"/", ""}, + } + + for _, tc := range tests { + t.Run(tc.path, func(t *testing.T) { + result := oidc.determineLogoutPath(tc.path) + if result != tc.expected { + t.Errorf("determineLogoutPath(%q) = %q, expected %q", tc.path, result, tc.expected) + } + }) + } +} + +// TestSessionInvalidationWithNilCache tests that session invalidation handles nil cache gracefully +func TestSessionInvalidationWithNilCache(t *testing.T) { + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + sessionInvalidationCache: nil, + } + + // Should return error for nil cache + err := oidc.invalidateSession("session123", "") + if err == nil { + t.Error("Expected error for nil cache") + } + + // isSessionInvalidated should return false for nil cache + if oidc.isSessionInvalidated("session123", "", time.Now()) { + t.Error("Expected false for nil cache") + } +} + +// TestBackchannelLogoutWithSubOnly tests logout with subject claim only (no sid) +func TestBackchannelLogoutWithSubOnly(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-sub-only", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sub": "user@example.com", // Only sub, no sid + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String()) + } + + // Verify subject was invalidated + key := oidc.buildSessionInvalidationKey("sub", "user@example.com") + if _, found := mockCache.data[key]; !found { + t.Error("Subject should have been invalidated") + } +} + +// TestBackchannelLogoutWithBothSidAndSub tests logout with both sid and sub claims +func TestBackchannelLogoutWithBothSidAndSub(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-both", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + "sub": "user@example.com", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rw.Code) + } + + // Both sid and sub should be invalidated + sidKey := oidc.buildSessionInvalidationKey("sid", "session-123") + subKey := oidc.buildSessionInvalidationKey("sub", "user@example.com") + if _, found := mockCache.data[sidKey]; !found { + t.Error("Session ID should have been invalidated") + } + if _, found := mockCache.data[subKey]; !found { + t.Error("Subject should have been invalidated") + } +} + +// TestBackchannelLogoutWrongIssuer tests that wrong issuer is rejected +func TestBackchannelLogoutWrongIssuer(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://wrong-issuer.com", // Wrong issuer + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-wrong-iss", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for wrong issuer, got %d", rw.Code) + } +} + +// TestBackchannelLogoutWrongAudience tests that wrong audience is rejected +func TestBackchannelLogoutWrongAudience(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "wrong-client-id", // Wrong audience + "iat": time.Now().Unix(), + "jti": "unique-id-wrong-aud", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for wrong audience, got %d", rw.Code) + } +} + +// TestBackchannelLogoutExpiredToken tests that expired tokens are rejected +func TestBackchannelLogoutExpiredToken(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token issued 20 minutes ago (> 15 min allowed) + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Add(-20 * time.Minute).Unix(), // Too old + "jti": "unique-id-expired", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for expired token, got %d", rw.Code) + } +} + +// TestBackchannelLogoutFutureToken tests that future-dated tokens are rejected +func TestBackchannelLogoutFutureToken(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token issued 10 minutes in the future (> 5 min clock skew allowed) + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Add(10 * time.Minute).Unix(), // Future + "jti": "unique-id-future", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for future token, got %d", rw.Code) + } +} + +// TestBackchannelLogoutMissingEvents tests that missing events claim is rejected +func TestBackchannelLogoutMissingEvents(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token without events claim + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-no-events", + "sid": "session-123", + // No events claim + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing events, got %d", rw.Code) + } +} + +// TestBackchannelLogoutWrongEventType tests that wrong event type is rejected +func TestBackchannelLogoutWrongEventType(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token with wrong event type + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-wrong-event", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/wrong-event": map[string]interface{}{}, // Wrong event + }, + "sid": "session-123", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for wrong event type, got %d", rw.Code) + } +} + +// TestBackchannelLogoutWithNonce tests that nonce presence is rejected +func TestBackchannelLogoutWithNonce(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token with nonce (not allowed per spec) + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-with-nonce", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + "nonce": "should-not-be-here", // Nonce not allowed + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for token with nonce, got %d", rw.Code) + } +} + +// TestBackchannelLogoutRawJWTBody tests logout with raw JWT in body (not form-urlencoded) +func TestBackchannelLogoutRawJWTBody(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-raw-body", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-raw-body", + }) + + // Send raw JWT in body (no form encoding) + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", strings.NewReader(logoutToken)) + req.Header.Set("Content-Type", "application/jwt") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String()) + } + + // Verify session was invalidated + key := oidc.buildSessionInvalidationKey("sid", "session-raw-body") + if _, found := mockCache.data[key]; !found { + t.Error("Session should have been invalidated") + } +} + +// TestBackchannelLogoutArrayAudience tests logout with array audience claim +func TestBackchannelLogoutArrayAudience(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Array audience containing our client ID + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": []string{"other-client", "test-client", "another-client"}, + "iat": time.Now().Unix(), + "jti": "unique-id-array-aud", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-array-aud", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String()) + } +} + +// TestFrontchannelLogoutWithSubOnly tests front-channel logout with sub parameter only +func TestFrontchannelLogoutWithSubOnly(t *testing.T) { + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableFrontchannelLogout: true, + frontchannelLogoutPath: "/frontchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + } + + // Front-channel with sub parameter (some IdPs use this) + req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sub=user@example.com&iss=https://provider.example.com", nil) + rw := httptest.NewRecorder() + + oidc.handleFrontchannelLogout(rw, req) + + // Should fail because sid is required + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 (sid required), got %d", rw.Code) + } +} + +// TestFrontchannelLogoutCacheControl tests that front-channel logout sets proper cache headers +func TestFrontchannelLogoutCacheControl(t *testing.T) { + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableFrontchannelLogout: true, + frontchannelLogoutPath: "/frontchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + } + + req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil) + rw := httptest.NewRecorder() + + oidc.handleFrontchannelLogout(rw, req) + + if rw.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rw.Code) + } + + // Check cache headers + cacheControl := rw.Header().Get("Cache-Control") + if !strings.Contains(cacheControl, "no-cache") || !strings.Contains(cacheControl, "no-store") { + t.Errorf("Expected Cache-Control to contain no-cache and no-store, got %s", cacheControl) + } + + pragma := rw.Header().Get("Pragma") + if pragma != "no-cache" { + t.Errorf("Expected Pragma: no-cache, got %s", pragma) + } + + // X-Frame-Options should be removed (to allow iframe embedding) + if rw.Header().Get("X-Frame-Options") != "" { + t.Error("X-Frame-Options should be removed for front-channel logout") + } +} + +// TestConcurrentSessionInvalidation tests concurrent session invalidations +func TestConcurrentSessionInvalidation(t *testing.T) { + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + sessionInvalidationCache: mockCache, + } + + // Invalidate multiple sessions concurrently + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(idx int) { + sid := fmt.Sprintf("session-%d", idx) + sub := fmt.Sprintf("user%d@example.com", idx) + err := oidc.invalidateSession(sid, sub) + if err != nil { + t.Errorf("Failed to invalidate session %d: %v", idx, err) + } + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + + // Verify all sessions were invalidated + for i := 0; i < 10; i++ { + sid := fmt.Sprintf("session-%d", i) + sub := fmt.Sprintf("user%d@example.com", i) + sidKey := oidc.buildSessionInvalidationKey("sid", sid) + subKey := oidc.buildSessionInvalidationKey("sub", sub) + if _, found := mockCache.Get(sidKey); !found { + t.Errorf("Session %d should have been invalidated by sid", i) + } + if _, found := mockCache.Get(subKey); !found { + t.Errorf("Session %d should have been invalidated by sub", i) + } + } +} + +// TestSessionInvalidationTimeComparison tests the time comparison logic +func TestSessionInvalidationTimeComparison(t *testing.T) { + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + sessionInvalidationCache: mockCache, + } + + // Create session at specific time + sessionCreatedAt := time.Now() + + // Wait a tiny bit and invalidate + time.Sleep(10 * time.Millisecond) + _ = oidc.invalidateSession("session-time-test", "") + + // Session created before invalidation should be invalidated + if !oidc.isSessionInvalidated("session-time-test", "", sessionCreatedAt) { + t.Error("Session created before invalidation should be marked as invalidated") + } + + // Session created after invalidation (simulated) should NOT be invalidated + futureSession := time.Now().Add(1 * time.Second) + if oidc.isSessionInvalidated("session-time-test", "", futureSession) { + t.Error("Session created after invalidation should NOT be marked as invalidated") + } +} + +// TestBackchannelLogoutMissingIat tests that missing iat is rejected +func TestBackchannelLogoutMissingIat(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token without iat claim + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + // No iat + "jti": "unique-id-no-iat", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + "sid": "session-123", + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing iat, got %d", rw.Code) + } +} + +// TestBackchannelLogoutMissingSidAndSub tests that missing both sid and sub is rejected +func TestBackchannelLogoutMissingSidAndSub(t *testing.T) { + privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + mockCache := &mockCacheInterface{data: make(map[string]interface{})} + x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes()) + y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes()) + + mockJWKCache := &staticJWKCache{ + jwks: &JWKSet{ + Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}}, + }, + } + + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + enableBackchannelLogout: true, + backchannelLogoutPath: "/backchannel-logout", + sessionInvalidationCache: mockCache, + clientID: "test-client", + issuerURL: "https://provider.example.com", + jwkCache: mockJWKCache, + jwksURL: "https://provider.example.com/.well-known/jwks.json", + } + + // Token without sid or sub + logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{ + "iss": "https://provider.example.com", + "aud": "test-client", + "iat": time.Now().Unix(), + "jti": "unique-id-no-sid-sub", + "events": map[string]interface{}{ + "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{}, + }, + // No sid or sub + }) + + req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", + strings.NewReader("logout_token="+url.QueryEscape(logoutToken))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rw := httptest.NewRecorder() + + oidc.handleBackchannelLogout(rw, req) + + if rw.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for missing sid and sub, got %d", rw.Code) + } +} + +// createSignedLogoutToken is a helper to create properly signed logout tokens for testing +func createSignedLogoutToken(t *testing.T, privateKey *ecdsa.PrivateKey, claims map[string]interface{}) string { + t.Helper() + + header := map[string]interface{}{ + "alg": "ES256", + "typ": "logout+jwt", + "kid": "test-key-1", + } + headerJSON, _ := json.Marshal(header) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + + claimsJSON, _ := json.Marshal(claims) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + // Sign the token + signingInput := headerB64 + "." + claimsB64 + hash := sha256.Sum256([]byte(signingInput)) + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + + // Convert signature to fixed-size format (32 bytes each for P-256) + sigBytes := make([]byte, 64) + rBytes := r.Bytes() + sBytes := s.Bytes() + copy(sigBytes[32-len(rBytes):32], rBytes) + copy(sigBytes[64-len(sBytes):], sBytes) + signatureB64 := base64.RawURLEncoding.EncodeToString(sigBytes) + + return headerB64 + "." + claimsB64 + "." + signatureB64 +} diff --git a/main.go b/main.go index 84a63e8..de39983 100644 --- a/main.go +++ b/main.go @@ -212,16 +212,21 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name } return 60 * time.Second }(), - tokenCleanupStopChan: make(chan struct{}), - metadataRefreshStopChan: make(chan struct{}), - ctx: pluginCtx, - cancelFunc: cancelFunc, - suppressDiagnosticLogs: isTestMode(), - securityHeadersApplier: config.GetSecurityHeadersApplier(), - scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering - dcrConfig: config.DynamicClientRegistration, - allowPrivateIPAddresses: config.AllowPrivateIPAddresses, - minimalHeaders: config.MinimalHeaders, + tokenCleanupStopChan: make(chan struct{}), + metadataRefreshStopChan: make(chan struct{}), + ctx: pluginCtx, + cancelFunc: cancelFunc, + suppressDiagnosticLogs: isTestMode(), + securityHeadersApplier: config.GetSecurityHeadersApplier(), + scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering + dcrConfig: config.DynamicClientRegistration, + allowPrivateIPAddresses: config.AllowPrivateIPAddresses, + minimalHeaders: config.MinimalHeaders, + enableBackchannelLogout: config.EnableBackchannelLogout, + enableFrontchannelLogout: config.EnableFrontchannelLogout, + backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL), + frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL), + sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(), } // Log audience configuration diff --git a/middleware.go b/middleware.go index ebd36e8..5b7f742 100644 --- a/middleware.go +++ b/middleware.go @@ -26,6 +26,31 @@ import ( // - rw: The HTTP response writer. // - req: The incoming HTTP request. func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // Log request entry for debugging routing issues + t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path) + + // Handle logout requests early - before waiting for OIDC initialization + // This allows users to logout even if the OIDC provider is unavailable + if req.URL.Path == t.logoutURLPath { + t.logger.Debugf("Logout path matched early: %s", req.URL.Path) + t.handleLogout(rw, req) + return + } + + // Handle backchannel logout (IdP-initiated POST with logout_token) + if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath { + t.logger.Debug("Backchannel logout path matched") + t.handleBackchannelLogout(rw, req) + return + } + + // Handle front-channel logout (IdP-initiated GET with sid/iss in iframe) + if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath { + t.logger.Debug("Front-channel logout path matched") + t.handleFrontchannelLogout(rw, req) + return + } + if !strings.HasPrefix(req.URL.Path, "/health") { t.firstRequestMutex.Lock() if !t.firstRequestReceived { @@ -42,6 +67,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.firstRequestMutex.Unlock() } + // Check excluded URLs before waiting for initialization + if t.determineExcludedURL(req.URL.Path) { + t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path) + t.next.ServeHTTP(rw, req) + return + } + + // Check for SSE requests before waiting for initialization + acceptHeader := req.Header.Get("Accept") + if strings.Contains(acceptHeader, "text/event-stream") { + t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader) + t.next.ServeHTTP(rw, req) + return + } + + // Log waiting for initialization to help diagnose hanging requests + t.logger.Debug("Waiting for OIDC provider initialization...") + select { case <-t.initComplete: // Read issuerURL with RLock @@ -83,7 +126,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.next.ServeHTTP(rw, req) return } - acceptHeader := req.Header.Get("Accept") + acceptHeader = req.Header.Get("Accept") if strings.Contains(acceptHeader, "text/event-stream") { t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader) // Set forwarded user headers from existing session before bypassing @@ -100,7 +143,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.next.ServeHTTP(rw, req) return } - t.sessionManager.CleanupOldCookies(rw, req) session, err := t.sessionManager.GetSession(req) @@ -131,10 +173,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { host := utils.DetermineHost(req) redirectURL := buildFullURL(scheme, host, t.redirURLPath) - if req.URL.Path == t.logoutURLPath { - t.handleLogout(rw, req) - return - } if req.URL.Path == t.redirURLPath { t.handleCallback(rw, req, redirectURL) return @@ -275,6 +313,24 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http return } + // Check if session has been invalidated via backchannel or front-channel logout + if t.enableBackchannelLogout || t.enableFrontchannelLogout { + idToken := session.GetIDToken() + if idToken != "" { + sid, sub, createdAt := t.extractSessionInfo(idToken) + if t.isSessionInvalidated(sid, sub, createdAt) { + t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email) + // Clear the session and redirect to login + if err := session.Clear(req, rw); err != nil { + t.logger.Errorf("Error clearing invalidated session: %v", err) + } + session.ResetRedirectCount() + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + return + } + } + } + tokenForClaims := session.GetIDToken() if tokenForClaims == "" { tokenForClaims = session.GetAccessToken() diff --git a/middleware_edge_cases_test.go b/middleware_edge_cases_test.go index e0a265b..8deabf8 100644 --- a/middleware_edge_cases_test.go +++ b/middleware_edge_cases_test.go @@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) { } } +// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable +// This is critical for allowing users to clear their session when the provider is down +func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) { + oidc := &TraefikOidc{ + logger: NewLogger("debug"), + initComplete: make(chan struct{}), // Never close to simulate provider unavailable + sessionManager: createTestSessionManager(t), + firstRequestReceived: true, + metadataRefreshStarted: true, + logoutURLPath: "/logout", + postLogoutRedirectURI: "/", + forceHTTPS: false, + } + // Note: initComplete is NOT closed, simulating OIDC provider being unavailable + + req := httptest.NewRequest("GET", "/logout", nil) + req.Host = "example.com" + rw := httptest.NewRecorder() + + oidc.ServeHTTP(rw, req) + + // Should redirect to post-logout URI even without OIDC initialization + if rw.Code != http.StatusFound { + t.Errorf("Expected redirect (302) for logout, got %d", rw.Code) + } + + location := rw.Header().Get("Location") + if location == "" { + t.Error("Expected Location header for logout redirect") + } +} + // TestMiddlewareDomainRestrictions tests domain-based access control // NOTE: Currently commented out due to complex session setup requirements // These scenarios are tested indirectly through integration tests diff --git a/settings.go b/settings.go index c17a664..ea5690e 100644 --- a/settings.go +++ b/settings.go @@ -65,6 +65,10 @@ type Config struct { ForceHTTPS bool `json:"forceHTTPS"` AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"` MinimalHeaders bool `json:"minimalHeaders,omitempty"` + EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"` + EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"` + BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"` + FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"` } // RedisConfig configures Redis cache backend settings for distributed caching. @@ -744,15 +748,6 @@ func newNoOpLogger() *Logger { // - code: The HTTP status code for the response. // - logger: The Logger instance to use for logging the error. // -// handleError writes an HTTP error response with the specified status code and message. -// It logs the error and sets appropriate headers before writing the response. -// -//lint:ignore U1000 Kept for potential future error handling -func handleError(w http.ResponseWriter, message string, code int, logger *Logger) { - logger.Error("%s", message) - http.Error(w, message, code) -} - // GetSecurityHeadersApplier returns a function that applies security headers func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) { if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled { @@ -1058,111 +1053,6 @@ func (rc *RedisConfig) ApplyEnvFallbacks() { } } -// LoadRedisConfigFromEnv loads Redis configuration from environment variables. -// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead. -// This function is kept for backward compatibility but should not be used directly. -func LoadRedisConfigFromEnv() *RedisConfig { - // Check if Redis is enabled - enabledStr := os.Getenv("REDIS_ENABLED") - if enabledStr == "" || enabledStr == "false" || enabledStr == "0" { - return nil - } - - config := &RedisConfig{ - Enabled: true, - } - - // Parse numeric values - if dbStr := os.Getenv("REDIS_DB"); dbStr != "" { - if db, err := strconv.Atoi(dbStr); err == nil { - config.DB = db - } - } - - if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" { - if poolSize, err := strconv.Atoi(poolSizeStr); err == nil { - config.PoolSize = poolSize - } - } - - if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" { - if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil { - config.ConnectTimeout = timeout - } - } - - if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" { - if timeout, err := strconv.Atoi(readTimeoutStr); err == nil { - config.ReadTimeout = timeout - } - } - - if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" { - if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil { - config.WriteTimeout = timeout - } - } - - // Parse boolean values - if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" { - config.EnableTLS = true - } - - if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" { - config.TLSSkipVerify = true - } - - // Parse hybrid mode settings - if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" { - if size, err := strconv.Atoi(l1SizeStr); err == nil { - config.HybridL1Size = size - } - } - - if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" { - if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil { - config.HybridL1MemoryMB = memory - } - } - - // Parse circuit breaker settings - if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" { - config.EnableCircuitBreaker = false - } else { - config.EnableCircuitBreaker = true // Default to enabled - } - - if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" { - if threshold, err := strconv.Atoi(cbThresholdStr); err == nil { - config.CircuitBreakerThreshold = threshold - } - } - - if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" { - if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil { - config.CircuitBreakerTimeout = timeout - } - } - - // Parse health check settings - if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" { - config.EnableHealthCheck = false - } else { - config.EnableHealthCheck = true // Default to enabled - } - - if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" { - if interval, err := strconv.Atoi(hcIntervalStr); err == nil { - config.HealthCheckInterval = interval - } - } - - // Apply defaults after loading from env - config.ApplyDefaults() - - return config -} - func isOriginAllowed(origin string, allowedOrigins []string) bool { for _, allowed := range allowedOrigins { if origin == allowed || allowed == "*" { diff --git a/types.go b/types.go index 46ebe9d..759470a 100644 --- a/types.go +++ b/types.go @@ -119,6 +119,8 @@ type TraefikOidc struct { clientID string clientSecret string registrationURL string + backchannelLogoutPath string + frontchannelLogoutPath string scopesSupported []string scopes []string refreshGracePeriod time.Duration @@ -126,7 +128,10 @@ type TraefikOidc struct { shutdownOnce sync.Once metadataRetryMutex sync.Mutex firstRequestMutex sync.Mutex + sessionInvalidationCache CacheInterface minimalHeaders bool + enableBackchannelLogout bool + enableFrontchannelLogout bool firstRequestReceived bool requireTokenIntrospection bool metadataRefreshStarted bool diff --git a/universal_cache.go b/universal_cache.go index 3cb4dc1..3207fd0 100644 --- a/universal_cache.go +++ b/universal_cache.go @@ -720,22 +720,6 @@ func (c *UniversalCache) SetWithMetadata(key string, value interface{}, ttl time return nil } -// GetTyped retrieves a typed value from the cache -func GetTyped[T any](c *UniversalCache, key string) (T, bool) { - var zero T - value, exists := c.Get(key) - if !exists { - return zero, false - } - - typed, ok := value.(T) - if !ok { - return zero, false - } - - return typed, true -} - // TokenCacheOperations provides token-specific operations func (c *UniversalCache) BlacklistToken(token string, ttl time.Duration) error { if c.config.Type != CacheTypeToken { diff --git a/universal_cache_singleton.go b/universal_cache_singleton.go index 9f07e21..3ccdde9 100644 --- a/universal_cache_singleton.go +++ b/universal_cache_singleton.go @@ -13,21 +13,22 @@ import ( // It runs a single consolidated cleanup goroutine for all caches, reducing // goroutine count and CPU overhead compared to per-cache cleanup routines. type UniversalCacheManager struct { - sharedBackend backends.CacheBackend - ctx context.Context - tokenTypeCache *UniversalCache - jwkCache *UniversalCache - sessionCache *UniversalCache - introspectionCache *UniversalCache - tokenCache *UniversalCache - metadataCache *UniversalCache - dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments - logger *Logger - blacklistCache *UniversalCache - cancel context.CancelFunc - wg sync.WaitGroup - mu sync.RWMutex - cleanupStarted bool + sharedBackend backends.CacheBackend + ctx context.Context + tokenTypeCache *UniversalCache + jwkCache *UniversalCache + sessionCache *UniversalCache + introspectionCache *UniversalCache + tokenCache *UniversalCache + metadataCache *UniversalCache + dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments + sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout + logger *Logger + blacklistCache *UniversalCache + cancel context.CancelFunc + wg sync.WaitGroup + mu sync.RWMutex + cleanupStarted bool } var ( @@ -170,6 +171,16 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) { Logger: logger, SkipAutoCleanup: true, // Managed cleanup }) + + // Initialize session invalidation cache for backchannel/front-channel logout + // This cache stores invalidated session IDs and subjects to revoke sessions + manager.sessionInvalidationCache = NewUniversalCache(UniversalCacheConfig{ + Type: CacheTypeSession, + MaxSize: 5000, // Support many concurrent invalidations + DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h) + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup + }) } // initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration @@ -363,6 +374,19 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r createBackend("dcr"), ) + // Session invalidation cache - CRITICAL for distributed backchannel/front-channel logout + // Uses Redis backend to share session invalidations across all Traefik replicas + manager.sessionInvalidationCache = NewUniversalCacheWithBackend( + UniversalCacheConfig{ + Type: CacheTypeSession, + MaxSize: 5000, // Support many concurrent invalidations + DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h) + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup + }, + createBackend("session_invalidation"), + ) + logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode) } @@ -411,6 +435,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() { m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, + m.sessionInvalidationCache, } m.mu.RUnlock() @@ -452,13 +477,6 @@ func (m *UniversalCacheManager) GetJWKCache() *UniversalCache { return m.jwkCache } -// GetSessionCache returns the session cache -func (m *UniversalCacheManager) GetSessionCache() *UniversalCache { - m.mu.RLock() - defer m.mu.RUnlock() - return m.sessionCache -} - // GetIntrospectionCache returns the token introspection cache func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache { m.mu.RLock() @@ -473,6 +491,13 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache { return m.tokenTypeCache } +// GetSessionInvalidationCache returns the session invalidation cache for backchannel/front-channel logout +func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache { + m.mu.RLock() + defer m.mu.RUnlock() + return m.sessionInvalidationCache +} + // GetDCRCredentialsCache returns the DCR credentials cache for distributed storage func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache { m.mu.RLock() @@ -495,7 +520,7 @@ func (m *UniversalCacheManager) Close() error { // Close all caches first (they won't close the shared backend) for _, cache := range []*UniversalCache{ - m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, + m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, } { if cache != nil { _ = cache.Close() // Safe to ignore: best effort cache cleanup @@ -516,35 +541,6 @@ func (m *UniversalCacheManager) Close() error { return nil } -// InitializeCacheManagerFromConfig initializes the cache manager with configuration -// This should be called early in the application startup with the loaded configuration -func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager { - logger := NewLogger(config.LogLevel) - - // Initialize Redis config if not present - if config.Redis == nil { - config.Redis = &RedisConfig{} - } - - // Apply environment variable fallbacks for fields not set in config - // This allows env vars to be used as optional overrides only when - // the config field is not explicitly set through Traefik - config.Redis.ApplyEnvFallbacks() - - // Apply defaults after env fallbacks - config.Redis.ApplyDefaults() - - // Log cache backend selection - if config.Redis != nil && config.Redis.Enabled { - logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s", - config.Redis.CacheMode, config.Redis.Address) - } else { - logger.Info("Initializing cache backend with memory-only mode") - } - - return GetUniversalCacheManagerWithConfig(logger, config.Redis) -} - // ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only // This should only be called in test code to ensure proper cleanup between tests func ResetUniversalCacheManagerForTesting() {