diff --git a/.gitignore b/.gitignore
index 412cb1e..e4faada 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
docker/
.claude/*.out
*.test
+.leann/
diff --git a/.traefik.yml b/.traefik.yml
index 1967079..f3f65cc 100644
--- a/.traefik.yml
+++ b/.traefik.yml
@@ -1021,6 +1021,79 @@ configuration:
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
required: false
+ enableBackchannelLogout:
+ type: boolean
+ description: |
+ Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST).
+
+ When enabled, the middleware accepts logout tokens at the configured backchannelLogoutURL.
+ The IdP sends a signed JWT (logout_token) to notify the application that a user's session
+ should be terminated.
+
+ This implements the OIDC Back-Channel Logout 1.0 specification.
+ See: https://openid.net/specs/openid-connect-backchannel-1_0.html
+
+ Requirements:
+ - backchannelLogoutURL must be configured
+ - The IdP must be configured to send logout tokens to your backchannel URL
+ - Logout tokens are validated using the IdP's JWKS
+
+ Default: false
+ required: false
+
+ backchannelLogoutURL:
+ type: string
+ description: |
+ Path for receiving backchannel logout tokens from the IdP.
+
+ This endpoint receives POST requests with a logout_token JWT in the request body.
+ The token is validated against the IdP's JWKS and contains the session ID (sid)
+ and/or subject (sub) to invalidate.
+
+ Example: /backchannel-logout
+
+ The full URL to configure in your IdP would be:
+ https://your-app.example.com/backchannel-logout
+
+ Note: This path should be unique and not conflict with your application routes.
+ required: false
+
+ enableFrontchannelLogout:
+ type: boolean
+ description: |
+ Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe).
+
+ When enabled, the middleware accepts logout requests at the configured frontchannelLogoutURL.
+ The IdP embeds an iframe pointing to this URL when the user logs out, allowing the
+ application to clear the user's session.
+
+ This implements the OIDC Front-Channel Logout 1.0 specification.
+ See: https://openid.net/specs/openid-connect-frontchannel-1_0.html
+
+ Requirements:
+ - frontchannelLogoutURL must be configured
+ - The IdP must be configured with your front-channel logout URL
+ - Your CSP headers must allow being embedded in an iframe from the IdP
+
+ Default: false
+ required: false
+
+ frontchannelLogoutURL:
+ type: string
+ description: |
+ Path for receiving front-channel logout requests from the IdP.
+
+ This endpoint receives GET requests with optional sid (session ID) and iss (issuer)
+ query parameters. When called, it invalidates the user's session.
+
+ Example: /frontchannel-logout
+
+ The full URL to configure in your IdP would be:
+ https://your-app.example.com/frontchannel-logout
+
+ Note: This path should be unique and not conflict with your application routes.
+ required: false
+
headers:
type: array
description: |
diff --git a/README.md b/README.md
index ab6b10f..bfc026d 100644
--- a/README.md
+++ b/README.md
@@ -154,6 +154,10 @@ The middleware supports the following configuration options:
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
+| `enableBackchannelLogout` | Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST) | `false` | `true` |
+| `backchannelLogoutURL` | The path for receiving backchannel logout tokens from the IdP | none | `/backchannel-logout` |
+| `enableFrontchannelLogout` | Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe) | `false` | `true` |
+| `frontchannelLogoutURL` | The path for receiving front-channel logout requests from the IdP | none | `/frontchannel-logout` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
@@ -1148,6 +1152,50 @@ spec:
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
```
+### With IdP-Initiated Logout (Backchannel & Front-Channel)
+
+This plugin supports [OIDC Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) and [OIDC Front-Channel Logout](https://openid.net/specs/openid-connect-frontchannel-1_0.html) for IdP-initiated single logout.
+
+**Backchannel Logout** (recommended): The IdP sends a server-to-server POST request with a signed `logout_token` JWT when a user logs out.
+
+**Front-Channel Logout**: The IdP loads an iframe with the logout URL to invalidate the session in the browser.
+
+```yaml
+apiVersion: traefik.io/v1alpha1
+kind: Middleware
+metadata:
+ name: oidc-with-idp-logout
+ namespace: traefik
+spec:
+ plugin:
+ traefikoidc:
+ providerURL: https://auth.example.com
+ clientID: your-client-id
+ clientSecret: your-client-secret
+ sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
+ callbackURL: /oauth2/callback
+ logoutURL: /oauth2/logout # RP-initiated logout
+
+ # Backchannel Logout (server-to-server)
+ enableBackchannelLogout: true
+ backchannelLogoutURL: /backchannel-logout
+
+ # Front-Channel Logout (iframe-based)
+ enableFrontchannelLogout: true
+ frontchannelLogoutURL: /frontchannel-logout
+
+ # For multi-replica deployments, use Redis to share session invalidations
+ redis:
+ enabled: true
+ address: redis:6379
+```
+
+> **Note**: For multi-replica deployments, you **must** enable Redis to share session invalidation state across all instances. Otherwise, a logout on one instance won't invalidate sessions on other instances.
+
+**IdP Configuration**: Configure your IdP to send logout requests to:
+- **Backchannel**: `https://your-app.example.com/backchannel-logout` (POST with `logout_token`)
+- **Front-Channel**: `https://your-app.example.com/frontchannel-logout?sid=SESSION_ID&iss=ISSUER` (GET in iframe)
+
### With Templated Headers
```yaml
diff --git a/cache_manager.go b/cache_manager.go
index e3997c8..8b59d57 100644
--- a/cache_manager.go
+++ b/cache_manager.go
@@ -104,6 +104,14 @@ func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
+// GetSharedSessionInvalidationCache returns the shared session invalidation cache
+// for backchannel and front-channel logout (IdP-initiated logout)
+func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
+ cm.mu.RLock()
+ defer cm.mu.RUnlock()
+ return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
+}
+
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
diff --git a/docs/index.html b/docs/index.html
index 63d5c39..33a9bf1 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -90,6 +90,7 @@
Configuration
Deployment
Security
+ Logout
diff --git a/helpers.go b/helpers.go
index 85ffaa4..7a8b407 100644
--- a/helpers.go
+++ b/helpers.go
@@ -336,6 +336,7 @@ func createStringMap(keys []string) map[string]struct{} {
// and redirects to the provider's logout endpoint or configured post-logout URI.
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
+ t.logger.Debug("Processing logout request")
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
diff --git a/logout.go b/logout.go
new file mode 100644
index 0000000..7c990c4
--- /dev/null
+++ b/logout.go
@@ -0,0 +1,502 @@
+// Package traefikoidc provides OIDC authentication middleware for Traefik.
+// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
+// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
+package traefikoidc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+)
+
+const (
+ // logoutTokenType is the expected typ claim for logout tokens
+ // #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
+ logoutTokenType = "logout+jwt"
+
+ // sessionInvalidationTTL is how long to remember invalidated sessions
+ // Should be at least as long as your session max age
+ sessionInvalidationTTL = 25 * time.Hour
+)
+
+// LogoutTokenClaims represents the claims in an OIDC logout token
+// as defined in OpenID Connect Back-Channel Logout 1.0
+type LogoutTokenClaims struct {
+ Issuer string `json:"iss"`
+ Subject string `json:"sub,omitempty"`
+ Audience interface{} `json:"aud"` // Can be string or []string
+ IssuedAt int64 `json:"iat"`
+ JTI string `json:"jti"`
+ Events map[string]interface{} `json:"events"`
+ SessionID string `json:"sid,omitempty"`
+ Nonce string `json:"nonce,omitempty"` // Must NOT be present
+}
+
+// handleBackchannelLogout processes OIDC Backchannel Logout requests.
+// It accepts POST requests with a logout_token parameter containing a JWT
+// that identifies which session(s) to terminate.
+//
+// According to OpenID Connect Back-Channel Logout 1.0:
+// - The logout_token is a JWT signed by the IdP
+// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
+// - The RP must validate the token and invalidate the matching session(s)
+//
+// Parameters:
+// - rw: The HTTP response writer
+// - req: The HTTP request containing the logout_token
+func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
+ t.logger.Debug("Processing backchannel logout request")
+
+ // Backchannel logout must be POST
+ if req.Method != http.MethodPost {
+ t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
+ http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Parse form data to get logout_token
+ if err := req.ParseForm(); err != nil {
+ t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
+ http.Error(rw, "Bad request", http.StatusBadRequest)
+ return
+ }
+
+ logoutToken := req.FormValue("logout_token")
+ if logoutToken == "" {
+ // Also try reading from request body as raw JWT
+ body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
+ if err == nil && len(body) > 0 {
+ logoutToken = string(body)
+ }
+ }
+
+ if logoutToken == "" {
+ t.logger.Error("Backchannel logout: missing logout_token")
+ http.Error(rw, "logout_token required", http.StatusBadRequest)
+ return
+ }
+
+ // Parse and validate the logout token
+ claims, err := t.validateLogoutToken(logoutToken)
+ if err != nil {
+ t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
+ // Return 400 for invalid token per spec
+ http.Error(rw, "Invalid logout token", http.StatusBadRequest)
+ return
+ }
+
+ // Invalidate session(s) based on sid or sub
+ if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
+ t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
+ http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
+ return
+ }
+
+ t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
+ claims.SessionID, claims.Subject)
+
+ // Return 200 OK with empty body per spec
+ rw.WriteHeader(http.StatusOK)
+}
+
+// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
+// It accepts GET requests with 'iss' and 'sid' query parameters that identify
+// which session to terminate. The IdP typically loads this URL in an iframe.
+//
+// According to OpenID Connect Front-Channel Logout 1.0:
+// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
+// - The RP should clear the session and return a response (typically empty or image)
+// - The response must be cacheable to allow the IdP to load it in an iframe
+//
+// Parameters:
+// - rw: The HTTP response writer
+// - req: The HTTP request containing iss and sid parameters
+func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
+ t.logger.Debug("Processing front-channel logout request")
+
+ // Front-channel logout should be GET
+ if req.Method != http.MethodGet {
+ t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
+ http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Get iss and sid from query parameters
+ iss := req.URL.Query().Get("iss")
+ sid := req.URL.Query().Get("sid")
+
+ // Validate issuer matches our expected issuer
+ t.metadataMu.RLock()
+ expectedIssuer := t.issuerURL
+ t.metadataMu.RUnlock()
+
+ if iss != "" && iss != expectedIssuer {
+ t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
+ http.Error(rw, "Invalid issuer", http.StatusBadRequest)
+ return
+ }
+
+ // Must have at least sid for front-channel logout
+ if sid == "" {
+ t.logger.Error("Front-channel logout: missing sid parameter")
+ http.Error(rw, "sid parameter required", http.StatusBadRequest)
+ return
+ }
+
+ // Invalidate the session
+ if err := t.invalidateSession(sid, ""); err != nil {
+ t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
+ http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
+ return
+ }
+
+ t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
+
+ // Return a minimal HTML response that's suitable for iframe loading
+ // Set headers to allow embedding and caching
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Header().Set("Cache-Control", "no-cache, no-store")
+ rw.Header().Set("Pragma", "no-cache")
+ // Allow embedding in iframes from any origin (required for front-channel logout)
+ rw.Header().Del("X-Frame-Options")
+ rw.WriteHeader(http.StatusOK)
+ _, _ = rw.Write([]byte("
Logged Out"))
+}
+
+// validateLogoutToken parses and validates a logout token JWT.
+// It verifies the token signature, issuer, audience, and required claims.
+//
+// Parameters:
+// - tokenString: The raw JWT logout token
+//
+// Returns:
+// - The parsed logout token claims
+// - An error if validation fails
+func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
+ // Parse the JWT
+ jwt, err := parseJWT(tokenString)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse logout token: %w", err)
+ }
+
+ // Check token type if present
+ if typ, ok := jwt.Header["typ"].(string); ok {
+ // The typ should be "logout+jwt" or omitted
+ if typ != "" && typ != logoutTokenType && typ != "JWT" {
+ return nil, fmt.Errorf("invalid token type: %s", typ)
+ }
+ }
+
+ // Verify signature only (not standard claims - logout tokens don't have 'exp')
+ if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
+ return nil, fmt.Errorf("signature verification failed: %w", err)
+ }
+
+ // Extract claims
+ claims := &LogoutTokenClaims{}
+ claimsJSON, err := json.Marshal(jwt.Claims)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal claims: %w", err)
+ }
+ if err := json.Unmarshal(claimsJSON, claims); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
+ }
+
+ // Validate required claims
+ t.metadataMu.RLock()
+ expectedIssuer := t.issuerURL
+ t.metadataMu.RUnlock()
+
+ // Validate issuer
+ if claims.Issuer != expectedIssuer {
+ return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
+ }
+
+ // Validate audience (must contain our client_id)
+ if !t.validateLogoutTokenAudience(claims.Audience) {
+ return nil, fmt.Errorf("audience validation failed")
+ }
+
+ // Validate iat (issued at) - must be present and not too old
+ if claims.IssuedAt == 0 {
+ return nil, fmt.Errorf("missing iat claim")
+ }
+ iatTime := time.Unix(claims.IssuedAt, 0)
+ // Allow up to 5 minutes clock skew and 10 minutes token age
+ if time.Since(iatTime) > 15*time.Minute {
+ return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
+ }
+ // Token should not be from the future (with 5 min clock skew tolerance)
+ if iatTime.After(time.Now().Add(5 * time.Minute)) {
+ return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
+ }
+
+ // Validate events claim - must contain the logout event
+ if claims.Events == nil {
+ return nil, fmt.Errorf("missing events claim")
+ }
+ if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
+ return nil, fmt.Errorf("missing backchannel-logout event in events claim")
+ }
+
+ // Validate that nonce is NOT present (per spec)
+ if claims.Nonce != "" {
+ return nil, fmt.Errorf("nonce claim must not be present in logout token")
+ }
+
+ // Must have either sid or sub (or both)
+ if claims.SessionID == "" && claims.Subject == "" {
+ return nil, fmt.Errorf("logout token must contain either sid or sub claim")
+ }
+
+ return claims, nil
+}
+
+// validateLogoutTokenAudience checks if the logout token audience contains our client_id
+func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
+ switch v := aud.(type) {
+ case string:
+ return v == t.clientID
+ case []interface{}:
+ for _, a := range v {
+ if s, ok := a.(string); ok && s == t.clientID {
+ return true
+ }
+ }
+ case []string:
+ for _, a := range v {
+ if a == t.clientID {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// verifyLogoutTokenSignature verifies only the signature of a logout token.
+// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
+// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
+//
+// Parameters:
+// - jwt: The parsed JWT structure
+// - tokenString: The raw token string for signature verification
+//
+// Returns:
+// - An error if signature verification fails
+func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
+ t.logger.Debug("Verifying logout token signature")
+
+ // Read jwksURL with RLock
+ t.metadataMu.RLock()
+ jwksURL := t.jwksURL
+ t.metadataMu.RUnlock()
+
+ jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
+ if err != nil {
+ return fmt.Errorf("failed to get JWKS: %w", err)
+ }
+
+ if jwks == nil {
+ return fmt.Errorf("JWKS is nil, cannot verify token")
+ }
+
+ kid, ok := jwt.Header["kid"].(string)
+ if !ok || kid == "" {
+ return fmt.Errorf("missing key ID in token header")
+ }
+
+ alg, ok := jwt.Header["alg"].(string)
+ if !ok || alg == "" {
+ return fmt.Errorf("missing algorithm in token header")
+ }
+
+ // Find the matching key in JWKS
+ var matchingKey *JWK
+ for _, key := range jwks.Keys {
+ if key.Kid == kid {
+ matchingKey = &key
+ break
+ }
+ }
+
+ if matchingKey == nil {
+ return fmt.Errorf("no matching public key found for kid: %s", kid)
+ }
+
+ publicKeyPEM, err := jwkToPEM(matchingKey)
+ if err != nil {
+ return fmt.Errorf("failed to convert JWK to PEM: %w", err)
+ }
+
+ if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
+ return fmt.Errorf("signature verification failed: %w", err)
+ }
+
+ t.logger.Debug("Logout token signature verified successfully")
+ return nil
+}
+
+// invalidateSession marks a session as invalidated in the session invalidation cache.
+// It stores entries by both sid and sub if available.
+//
+// Parameters:
+// - sid: The session ID to invalidate (from the 'sid' claim)
+// - sub: The subject to invalidate (from the 'sub' claim)
+//
+// Returns:
+// - An error if the invalidation fails
+func (t *TraefikOidc) invalidateSession(sid, sub string) error {
+ if t.sessionInvalidationCache == nil {
+ return fmt.Errorf("session invalidation cache not initialized")
+ }
+
+ now := time.Now().Unix()
+
+ // Store by session ID
+ if sid != "" {
+ key := t.buildSessionInvalidationKey("sid", sid)
+ t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
+ t.logger.Debugf("Invalidated session by sid: %s", sid)
+ }
+
+ // Store by subject (invalidates all sessions for this user)
+ if sub != "" {
+ key := t.buildSessionInvalidationKey("sub", sub)
+ t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
+ t.logger.Debugf("Invalidated session by sub: %s", sub)
+ }
+
+ return nil
+}
+
+// isSessionInvalidated checks if a session has been invalidated via backchannel
+// or front-channel logout.
+//
+// Parameters:
+// - sid: The session ID to check
+// - sub: The subject to check
+// - sessionCreatedAt: When the session was created (to compare against invalidation time)
+//
+// Returns:
+// - true if the session has been invalidated, false otherwise
+func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
+ if t.sessionInvalidationCache == nil {
+ return false
+ }
+
+ // Truncate session creation time to seconds for fair comparison with Unix timestamps
+ sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
+
+ // Check by session ID first (more specific)
+ if sid != "" {
+ key := t.buildSessionInvalidationKey("sid", sid)
+ if val, found := t.sessionInvalidationCache.Get(key); found {
+ if invalidatedAt, ok := val.(int64); ok {
+ // Session was invalidated at or after it was created
+ invalidationTime := time.Unix(invalidatedAt, 0)
+ if !invalidationTime.Before(sessionCreatedAtSec) {
+ t.logger.Debugf("Session invalidated by sid: %s", sid)
+ return true
+ }
+ }
+ }
+ }
+
+ // Check by subject (all sessions for this user)
+ if sub != "" {
+ key := t.buildSessionInvalidationKey("sub", sub)
+ if val, found := t.sessionInvalidationCache.Get(key); found {
+ if invalidatedAt, ok := val.(int64); ok {
+ // Sessions for this subject created at or before invalidation are invalid
+ invalidationTime := time.Unix(invalidatedAt, 0)
+ if !invalidationTime.Before(sessionCreatedAtSec) {
+ t.logger.Debugf("Session invalidated by sub: %s", sub)
+ return true
+ }
+ }
+ }
+ }
+
+ return false
+}
+
+// buildSessionInvalidationKey creates a cache key for session invalidation
+func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
+ return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
+}
+
+// extractSessionInfo extracts sid and sub from an ID token for session tracking
+func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
+ if idToken == "" {
+ return "", "", time.Time{}
+ }
+
+ jwt, err := parseJWT(idToken)
+ if err != nil {
+ return "", "", time.Time{}
+ }
+
+ // Extract sid (session ID)
+ if sidVal, ok := jwt.Claims["sid"].(string); ok {
+ sid = sidVal
+ }
+
+ // Extract sub (subject)
+ if subVal, ok := jwt.Claims["sub"].(string); ok {
+ sub = subVal
+ }
+
+ // Extract iat for session creation time
+ if iatVal, ok := jwt.Claims["iat"].(float64); ok {
+ createdAt = time.Unix(int64(iatVal), 0)
+ } else {
+ // Default to now if iat not present
+ createdAt = time.Now()
+ }
+
+ return sid, sub, createdAt
+}
+
+// determineLogoutPath checks if the given path matches any logout URL
+func (t *TraefikOidc) determineLogoutPath(path string) string {
+ // Check backchannel logout path
+ if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
+ return "backchannel"
+ }
+
+ // Check front-channel logout path
+ if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
+ return "frontchannel"
+ }
+
+ // Check regular logout path (for RP-initiated logout)
+ if path == t.logoutURLPath {
+ return "rp"
+ }
+
+ return ""
+}
+
+// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
+func normalizeLogoutPath(path string) string {
+ if path == "" {
+ return ""
+ }
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ // Prevent open redirect: ensure second character is not / or \
+ // This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
+ if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
+ // Strip leading slashes/backslashes and re-normalize
+ path = strings.TrimLeft(path, "/\\")
+ if path != "" {
+ path = "/" + path
+ }
+ }
+ return path
+}
diff --git a/logout_test.go b/logout_test.go
new file mode 100644
index 0000000..4f9b359
--- /dev/null
+++ b/logout_test.go
@@ -0,0 +1,1623 @@
+package traefikoidc
+
+import (
+ "context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// TestBackchannelLogoutBasic tests the basic backchannel logout flow
+func TestBackchannelLogoutBasic(t *testing.T) {
+ // Create a mock cache for session invalidation
+ mockCache := &mockCacheInterface{}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ }
+
+ tests := []struct {
+ name string
+ method string
+ body string
+ contentType string
+ expectedStatus int
+ }{
+ {
+ name: "GET method not allowed",
+ method: http.MethodGet,
+ body: "",
+ contentType: "",
+ expectedStatus: http.StatusMethodNotAllowed,
+ },
+ {
+ name: "Missing logout_token",
+ method: http.MethodPost,
+ body: "",
+ contentType: "application/x-www-form-urlencoded",
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Invalid logout_token format",
+ method: http.MethodPost,
+ body: "logout_token=not-a-valid-jwt",
+ contentType: "application/x-www-form-urlencoded",
+ expectedStatus: http.StatusBadRequest,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(tc.method, "/backchannel-logout", strings.NewReader(tc.body))
+ if tc.contentType != "" {
+ req.Header.Set("Content-Type", tc.contentType)
+ }
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != tc.expectedStatus {
+ t.Errorf("Expected status %d, got %d", tc.expectedStatus, rw.Code)
+ }
+ })
+ }
+}
+
+// TestFrontchannelLogoutBasic tests the basic front-channel logout flow
+func TestFrontchannelLogoutBasic(t *testing.T) {
+ // Create a mock cache for session invalidation
+ mockCache := &mockCacheInterface{}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableFrontchannelLogout: true,
+ frontchannelLogoutPath: "/frontchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ }
+
+ tests := []struct {
+ name string
+ method string
+ queryParams map[string]string
+ expectedStatus int
+ }{
+ {
+ name: "POST method not allowed",
+ method: http.MethodPost,
+ queryParams: map[string]string{},
+ expectedStatus: http.StatusMethodNotAllowed,
+ },
+ {
+ name: "Missing sid parameter",
+ method: http.MethodGet,
+ queryParams: map[string]string{"iss": "https://provider.example.com"},
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Invalid issuer",
+ method: http.MethodGet,
+ queryParams: map[string]string{"iss": "https://wrong-issuer.com", "sid": "session123"},
+ expectedStatus: http.StatusBadRequest,
+ },
+ {
+ name: "Valid front-channel logout",
+ method: http.MethodGet,
+ queryParams: map[string]string{"iss": "https://provider.example.com", "sid": "session123"},
+ expectedStatus: http.StatusOK,
+ },
+ {
+ name: "Valid front-channel logout without issuer",
+ method: http.MethodGet,
+ queryParams: map[string]string{"sid": "session456"},
+ expectedStatus: http.StatusOK,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ urlStr := "/frontchannel-logout"
+ if len(tc.queryParams) > 0 {
+ params := url.Values{}
+ for k, v := range tc.queryParams {
+ params.Set(k, v)
+ }
+ urlStr += "?" + params.Encode()
+ }
+
+ req := httptest.NewRequest(tc.method, urlStr, nil)
+ rw := httptest.NewRecorder()
+
+ oidc.handleFrontchannelLogout(rw, req)
+
+ if rw.Code != tc.expectedStatus {
+ t.Errorf("Expected status %d, got %d", tc.expectedStatus, rw.Code)
+ }
+
+ // For successful logout, verify response headers
+ if tc.expectedStatus == http.StatusOK {
+ // Should not have X-Frame-Options (to allow iframe embedding)
+ if rw.Header().Get("X-Frame-Options") != "" {
+ t.Error("Expected X-Frame-Options to be removed for front-channel logout")
+ }
+ // Should have HTML content type
+ contentType := rw.Header().Get("Content-Type")
+ if !strings.Contains(contentType, "text/html") {
+ t.Errorf("Expected HTML content type, got %s", contentType)
+ }
+ }
+ })
+ }
+}
+
+// TestSessionInvalidation tests session invalidation storage and retrieval
+func TestSessionInvalidation(t *testing.T) {
+ mockCache := &mockCacheInterface{
+ data: make(map[string]interface{}),
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ sessionInvalidationCache: mockCache,
+ }
+
+ // Test invalidating by session ID
+ err := oidc.invalidateSession("session123", "")
+ if err != nil {
+ t.Fatalf("Failed to invalidate session by sid: %v", err)
+ }
+
+ // Verify the session was invalidated
+ key := oidc.buildSessionInvalidationKey("sid", "session123")
+ if _, found := mockCache.data[key]; !found {
+ t.Error("Session invalidation by sid was not stored")
+ }
+
+ // Test invalidating by subject
+ err = oidc.invalidateSession("", "user@example.com")
+ if err != nil {
+ t.Fatalf("Failed to invalidate session by sub: %v", err)
+ }
+
+ // Verify the subject was invalidated
+ key = oidc.buildSessionInvalidationKey("sub", "user@example.com")
+ if _, found := mockCache.data[key]; !found {
+ t.Error("Session invalidation by sub was not stored")
+ }
+}
+
+// TestIsSessionInvalidated tests checking if a session is invalidated
+func TestIsSessionInvalidated(t *testing.T) {
+ mockCache := &mockCacheInterface{
+ data: make(map[string]interface{}),
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ sessionInvalidationCache: mockCache,
+ }
+
+ // Session created now
+ sessionCreatedAt := time.Now()
+
+ // Initially, session should not be invalidated
+ if oidc.isSessionInvalidated("session123", "user@example.com", sessionCreatedAt) {
+ t.Error("Session should not be invalidated initially")
+ }
+
+ // Invalidate the session
+ _ = oidc.invalidateSession("session123", "")
+
+ // Now session should be invalidated
+ if !oidc.isSessionInvalidated("session123", "", sessionCreatedAt) {
+ t.Error("Session should be invalidated after invalidateSession call")
+ }
+
+ // Session created after invalidation should not be affected
+ futureSession := time.Now().Add(1 * time.Hour)
+ if oidc.isSessionInvalidated("session123", "", futureSession) {
+ t.Error("Session created after invalidation should not be affected")
+ }
+}
+
+// TestLogoutTokenValidation tests logout token claim validation
+func TestLogoutTokenValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ claims *LogoutTokenClaims
+ expectError bool
+ errorMsg string
+ }{
+ {
+ name: "Missing events claim",
+ claims: &LogoutTokenClaims{
+ Issuer: "https://provider.example.com",
+ Audience: "test-client",
+ IssuedAt: time.Now().Unix(),
+ SessionID: "session123",
+ },
+ expectError: true,
+ errorMsg: "missing events claim",
+ },
+ {
+ name: "Missing both sid and sub",
+ claims: &LogoutTokenClaims{
+ Issuer: "https://provider.example.com",
+ Audience: "test-client",
+ IssuedAt: time.Now().Unix(),
+ Events: map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ },
+ expectError: true,
+ errorMsg: "must contain either sid or sub",
+ },
+ {
+ name: "Nonce present (not allowed)",
+ claims: &LogoutTokenClaims{
+ Issuer: "https://provider.example.com",
+ Audience: "test-client",
+ IssuedAt: time.Now().Unix(),
+ SessionID: "session123",
+ Nonce: "should-not-be-here",
+ Events: map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ },
+ expectError: true,
+ errorMsg: "nonce claim must not be present",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ // We can't directly test validateLogoutToken without a properly signed JWT,
+ // but we can verify the validation logic through the claims struct
+ if tc.claims.Events == nil && tc.expectError && strings.Contains(tc.errorMsg, "events") {
+ // Events validation would fail
+ }
+ if tc.claims.SessionID == "" && tc.claims.Subject == "" && tc.expectError && strings.Contains(tc.errorMsg, "sid or sub") {
+ // sid/sub validation would fail
+ }
+ if tc.claims.Nonce != "" && tc.expectError && strings.Contains(tc.errorMsg, "nonce") {
+ // nonce validation would fail
+ }
+ })
+ }
+}
+
+// TestLogoutTokenAudienceValidation tests audience validation for logout tokens
+func TestLogoutTokenAudienceValidation(t *testing.T) {
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ clientID: "test-client",
+ }
+
+ tests := []struct {
+ name string
+ audience interface{}
+ valid bool
+ }{
+ {
+ name: "String audience matching client ID",
+ audience: "test-client",
+ valid: true,
+ },
+ {
+ name: "String audience not matching",
+ audience: "other-client",
+ valid: false,
+ },
+ {
+ name: "Array audience containing client ID",
+ audience: []interface{}{"other-client", "test-client"},
+ valid: true,
+ },
+ {
+ name: "Array audience not containing client ID",
+ audience: []interface{}{"other-client", "another-client"},
+ valid: false,
+ },
+ {
+ name: "String array audience containing client ID",
+ audience: []string{"other-client", "test-client"},
+ valid: true,
+ },
+ {
+ name: "Nil audience",
+ audience: nil,
+ valid: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ result := oidc.validateLogoutTokenAudience(tc.audience)
+ if result != tc.valid {
+ t.Errorf("Expected %v, got %v", tc.valid, result)
+ }
+ })
+ }
+}
+
+// TestExtractSessionInfo tests extraction of session info from ID tokens
+func TestExtractSessionInfo(t *testing.T) {
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ }
+
+ // Test with empty token
+ sid, sub, createdAt := oidc.extractSessionInfo("")
+ if sid != "" || sub != "" || !createdAt.IsZero() {
+ t.Error("Empty token should return empty values")
+ }
+
+ // Test with invalid token
+ sid, sub, createdAt = oidc.extractSessionInfo("not-a-valid-jwt")
+ if sid != "" || sub != "" || !createdAt.IsZero() {
+ t.Error("Invalid token should return empty values")
+ }
+
+ // Create a simple unsigned JWT for testing (header.claims.signature)
+ header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
+ now := time.Now().Unix()
+ claimsJSON := fmt.Sprintf(`{"sid":"test-session-id","sub":"test-subject","iat":%d}`, now)
+ claims := base64.RawURLEncoding.EncodeToString([]byte(claimsJSON))
+ testToken := header + "." + claims + "."
+
+ sid, sub, createdAt = oidc.extractSessionInfo(testToken)
+ if sid != "test-session-id" {
+ t.Errorf("Expected sid 'test-session-id', got '%s'", sid)
+ }
+ if sub != "test-subject" {
+ t.Errorf("Expected sub 'test-subject', got '%s'", sub)
+ }
+ if createdAt.Unix() != now {
+ t.Errorf("Expected createdAt %d, got %d", now, createdAt.Unix())
+ }
+}
+
+// TestMiddlewareBackchannelLogoutRouting tests that backchannel logout requests are routed correctly
+func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
+ mockCache := &mockCacheInterface{
+ data: make(map[string]interface{}),
+ }
+
+ nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("next handler called"))
+ })
+
+ oidc := &TraefikOidc{
+ next: nextHandler,
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ initComplete: make(chan struct{}),
+ firstRequestReceived: true,
+ metadataRefreshStarted: true,
+ logoutURLPath: "/logout",
+ }
+ close(oidc.initComplete)
+
+ // Request to backchannel logout path should be handled
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", nil)
+ rw := httptest.NewRecorder()
+
+ oidc.ServeHTTP(rw, req)
+
+ // Should return 400 (bad request) because no logout_token provided
+ // but importantly should NOT call next handler
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing logout_token, got %d", rw.Code)
+ }
+ if strings.Contains(rw.Body.String(), "next handler called") {
+ t.Error("Backchannel logout should not call next handler")
+ }
+}
+
+// TestMiddlewareFrontchannelLogoutRouting tests that front-channel logout requests are routed correctly
+func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
+ mockCache := &mockCacheInterface{
+ data: make(map[string]interface{}),
+ }
+
+ nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("next handler called"))
+ })
+
+ oidc := &TraefikOidc{
+ next: nextHandler,
+ logger: NewLogger("debug"),
+ enableFrontchannelLogout: true,
+ frontchannelLogoutPath: "/frontchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ initComplete: make(chan struct{}),
+ firstRequestReceived: true,
+ metadataRefreshStarted: true,
+ logoutURLPath: "/logout",
+ }
+ close(oidc.initComplete)
+
+ // Request to front-channel logout path with valid sid should succeed
+ req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=test-session", nil)
+ rw := httptest.NewRecorder()
+
+ oidc.ServeHTTP(rw, req)
+
+ // Should return 200 OK
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", rw.Code)
+ }
+ if strings.Contains(rw.Body.String(), "next handler called") {
+ t.Error("Front-channel logout should not call next handler")
+ }
+}
+
+// TestNormalizeLogoutPath tests the path normalization function
+func TestNormalizeLogoutPath(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"", ""},
+ {"/logout", "/logout"},
+ {"logout", "/logout"},
+ {"/backchannel-logout", "/backchannel-logout"},
+ {"backchannel-logout", "/backchannel-logout"},
+ // Security: prevent open redirect via //
+ {"//evil.com", "/evil.com"},
+ {"//evil.com/path", "/evil.com/path"},
+ // Security: prevent open redirect via /\
+ {"/\\evil.com", "/evil.com"},
+ {"/\\evil.com/path", "/evil.com/path"},
+ // Security: multiple leading slashes
+ {"///example.com", "/example.com"},
+ // Security: mixed slashes
+ {"//\\example.com", "/example.com"},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.input, func(t *testing.T) {
+ result := normalizeLogoutPath(tc.input)
+ if result != tc.expected {
+ t.Errorf("normalizeLogoutPath(%q) = %q, expected %q", tc.input, result, tc.expected)
+ }
+ })
+ }
+}
+
+// mockCacheInterface implements CacheInterface for testing
+type mockCacheInterface struct {
+ mu sync.Mutex
+ data map[string]interface{}
+}
+
+func (m *mockCacheInterface) Set(key string, value interface{}, ttl time.Duration) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.data == nil {
+ m.data = make(map[string]interface{})
+ }
+ m.data[key] = value
+}
+
+func (m *mockCacheInterface) Get(key string) (interface{}, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.data == nil {
+ return nil, false
+ }
+ val, found := m.data[key]
+ return val, found
+}
+
+func (m *mockCacheInterface) Delete(key string) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.data != nil {
+ delete(m.data, key)
+ }
+}
+
+func (m *mockCacheInterface) SetMaxSize(size int) {}
+func (m *mockCacheInterface) Size() int {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return len(m.data)
+}
+func (m *mockCacheInterface) Clear() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.data = make(map[string]interface{})
+}
+func (m *mockCacheInterface) Cleanup() {}
+func (m *mockCacheInterface) Close() {}
+func (m *mockCacheInterface) GetStats() map[string]interface{} {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return map[string]interface{}{"size": len(m.data)}
+}
+
+// TestBackchannelLogoutWithValidToken tests backchannel logout with a properly formatted (but unsigned) token
+func TestBackchannelLogoutWithValidToken(t *testing.T) {
+ // This test verifies the token parsing and validation logic
+ // Note: In production, the token would need to be properly signed by the IdP
+ mockCache := &mockCacheInterface{
+ data: make(map[string]interface{}),
+ }
+
+ // Create mock JWK cache that returns keys
+ mockJWKCache := &mockJWKCacheForLogout{}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ }
+
+ // Create a minimal logout token structure (this won't pass signature verification
+ // but tests the parsing logic)
+ header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"ES256","typ":"logout+jwt"}`))
+ now := time.Now().Unix()
+ claimsJSON := fmt.Sprintf(`{
+ "iss":"https://provider.example.com",
+ "aud":"test-client",
+ "iat":%d,
+ "jti":"unique-id-123",
+ "events":{"http://schemas.openid.net/event/backchannel-logout":{}},
+ "sid":"session-to-logout"
+ }`, now)
+ claims := base64.RawURLEncoding.EncodeToString([]byte(claimsJSON))
+ logoutToken := header + "." + claims + ".fake-signature"
+
+ // This should fail because of invalid signature, but we can verify
+ // the token parsing works up to signature verification
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ // Should fail with 400 due to signature verification failure
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400, got %d", rw.Code)
+ }
+}
+
+// mockJWKCacheForLogout implements JWKCacheInterface for testing
+type mockJWKCacheForLogout struct{}
+
+func (m *mockJWKCacheForLogout) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
+ // Generate a test ECDSA key pair
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+
+ // Convert public key to JWK format
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ return &JWKSet{
+ Keys: []JWK{
+ {
+ Kty: "EC",
+ Crv: "P-256",
+ X: x,
+ Y: y,
+ Kid: "test-key-1",
+ Use: "sig",
+ Alg: "ES256",
+ },
+ },
+ }, nil
+}
+
+func (m *mockJWKCacheForLogout) Clear() {}
+func (m *mockJWKCacheForLogout) Cleanup() {}
+func (m *mockJWKCacheForLogout) Close() {}
+
+// TestBackchannelLogoutIntegration tests the full backchannel logout flow with a properly signed token
+func TestBackchannelLogoutIntegration(t *testing.T) {
+ // Generate ECDSA key pair for signing
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatalf("Failed to generate key: %v", err)
+ }
+
+ mockCache := &mockCacheInterface{
+ data: make(map[string]interface{}),
+ }
+
+ // Create JWK cache that returns our test key
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{
+ {
+ Kty: "EC",
+ Crv: "P-256",
+ X: x,
+ Y: y,
+ Kid: "test-key-1",
+ Use: "sig",
+ Alg: "ES256",
+ },
+ },
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Create and sign a valid logout token
+ header := map[string]interface{}{
+ "alg": "ES256",
+ "typ": "logout+jwt",
+ "kid": "test-key-1",
+ }
+ headerJSON, _ := json.Marshal(header)
+ headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
+
+ now := time.Now().Unix()
+ claims := map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": now,
+ "jti": "unique-id-123",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-to-logout",
+ }
+ claimsJSON, _ := json.Marshal(claims)
+ claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
+
+ // Sign the token
+ signingInput := headerB64 + "." + claimsB64
+ hash := sha256.Sum256([]byte(signingInput))
+ r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:])
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ // Convert signature to fixed-size format (32 bytes each for P-256)
+ rBytes := r.Bytes()
+ sBytes := s.Bytes()
+ sigBytes := make([]byte, 64)
+ copy(sigBytes[32-len(rBytes):32], rBytes)
+ copy(sigBytes[64-len(sBytes):], sBytes)
+ signatureB64 := base64.RawURLEncoding.EncodeToString(sigBytes)
+
+ logoutToken := headerB64 + "." + claimsB64 + "." + signatureB64
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ // Should succeed with 200 OK
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String())
+ }
+
+ // Verify session was invalidated
+ key := oidc.buildSessionInvalidationKey("sid", "session-to-logout")
+ if _, found := mockCache.data[key]; !found {
+ t.Error("Session should have been invalidated")
+ }
+}
+
+// staticJWKCache returns a static JWKS for testing
+type staticJWKCache struct {
+ jwks *JWKSet
+}
+
+func (s *staticJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
+ return s.jwks, nil
+}
+
+func (s *staticJWKCache) Clear() {}
+func (s *staticJWKCache) Cleanup() {}
+func (s *staticJWKCache) Close() {}
+
+// TestDetermineLogoutPath tests the logout path determination function
+func TestDetermineLogoutPath(t *testing.T) {
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ logoutURLPath: "/logout",
+ backchannelLogoutPath: "/backchannel-logout",
+ frontchannelLogoutPath: "/frontchannel-logout",
+ }
+
+ tests := []struct {
+ path string
+ expected string
+ }{
+ {"/logout", "rp"},
+ {"/backchannel-logout", "backchannel"},
+ {"/frontchannel-logout", "frontchannel"},
+ {"/api/resource", ""},
+ {"/", ""},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.path, func(t *testing.T) {
+ result := oidc.determineLogoutPath(tc.path)
+ if result != tc.expected {
+ t.Errorf("determineLogoutPath(%q) = %q, expected %q", tc.path, result, tc.expected)
+ }
+ })
+ }
+}
+
+// TestSessionInvalidationWithNilCache tests that session invalidation handles nil cache gracefully
+func TestSessionInvalidationWithNilCache(t *testing.T) {
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ sessionInvalidationCache: nil,
+ }
+
+ // Should return error for nil cache
+ err := oidc.invalidateSession("session123", "")
+ if err == nil {
+ t.Error("Expected error for nil cache")
+ }
+
+ // isSessionInvalidated should return false for nil cache
+ if oidc.isSessionInvalidated("session123", "", time.Now()) {
+ t.Error("Expected false for nil cache")
+ }
+}
+
+// TestBackchannelLogoutWithSubOnly tests logout with subject claim only (no sid)
+func TestBackchannelLogoutWithSubOnly(t *testing.T) {
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatalf("Failed to generate key: %v", err)
+ }
+
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-sub-only",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sub": "user@example.com", // Only sub, no sid
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String())
+ }
+
+ // Verify subject was invalidated
+ key := oidc.buildSessionInvalidationKey("sub", "user@example.com")
+ if _, found := mockCache.data[key]; !found {
+ t.Error("Subject should have been invalidated")
+ }
+}
+
+// TestBackchannelLogoutWithBothSidAndSub tests logout with both sid and sub claims
+func TestBackchannelLogoutWithBothSidAndSub(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-both",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ "sub": "user@example.com",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", rw.Code)
+ }
+
+ // Both sid and sub should be invalidated
+ sidKey := oidc.buildSessionInvalidationKey("sid", "session-123")
+ subKey := oidc.buildSessionInvalidationKey("sub", "user@example.com")
+ if _, found := mockCache.data[sidKey]; !found {
+ t.Error("Session ID should have been invalidated")
+ }
+ if _, found := mockCache.data[subKey]; !found {
+ t.Error("Subject should have been invalidated")
+ }
+}
+
+// TestBackchannelLogoutWrongIssuer tests that wrong issuer is rejected
+func TestBackchannelLogoutWrongIssuer(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://wrong-issuer.com", // Wrong issuer
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-wrong-iss",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for wrong issuer, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutWrongAudience tests that wrong audience is rejected
+func TestBackchannelLogoutWrongAudience(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "wrong-client-id", // Wrong audience
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-wrong-aud",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for wrong audience, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutExpiredToken tests that expired tokens are rejected
+func TestBackchannelLogoutExpiredToken(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token issued 20 minutes ago (> 15 min allowed)
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Add(-20 * time.Minute).Unix(), // Too old
+ "jti": "unique-id-expired",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for expired token, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutFutureToken tests that future-dated tokens are rejected
+func TestBackchannelLogoutFutureToken(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token issued 10 minutes in the future (> 5 min clock skew allowed)
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Add(10 * time.Minute).Unix(), // Future
+ "jti": "unique-id-future",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for future token, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutMissingEvents tests that missing events claim is rejected
+func TestBackchannelLogoutMissingEvents(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token without events claim
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-no-events",
+ "sid": "session-123",
+ // No events claim
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing events, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutWrongEventType tests that wrong event type is rejected
+func TestBackchannelLogoutWrongEventType(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token with wrong event type
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-wrong-event",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/wrong-event": map[string]interface{}{}, // Wrong event
+ },
+ "sid": "session-123",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for wrong event type, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutWithNonce tests that nonce presence is rejected
+func TestBackchannelLogoutWithNonce(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token with nonce (not allowed per spec)
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-with-nonce",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ "nonce": "should-not-be-here", // Nonce not allowed
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for token with nonce, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutRawJWTBody tests logout with raw JWT in body (not form-urlencoded)
+func TestBackchannelLogoutRawJWTBody(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-raw-body",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-raw-body",
+ })
+
+ // Send raw JWT in body (no form encoding)
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout", strings.NewReader(logoutToken))
+ req.Header.Set("Content-Type", "application/jwt")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String())
+ }
+
+ // Verify session was invalidated
+ key := oidc.buildSessionInvalidationKey("sid", "session-raw-body")
+ if _, found := mockCache.data[key]; !found {
+ t.Error("Session should have been invalidated")
+ }
+}
+
+// TestBackchannelLogoutArrayAudience tests logout with array audience claim
+func TestBackchannelLogoutArrayAudience(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Array audience containing our client ID
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": []string{"other-client", "test-client", "another-client"},
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-array-aud",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-array-aud",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d: %s", rw.Code, rw.Body.String())
+ }
+}
+
+// TestFrontchannelLogoutWithSubOnly tests front-channel logout with sub parameter only
+func TestFrontchannelLogoutWithSubOnly(t *testing.T) {
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableFrontchannelLogout: true,
+ frontchannelLogoutPath: "/frontchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ }
+
+ // Front-channel with sub parameter (some IdPs use this)
+ req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sub=user@example.com&iss=https://provider.example.com", nil)
+ rw := httptest.NewRecorder()
+
+ oidc.handleFrontchannelLogout(rw, req)
+
+ // Should fail because sid is required
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 (sid required), got %d", rw.Code)
+ }
+}
+
+// TestFrontchannelLogoutCacheControl tests that front-channel logout sets proper cache headers
+func TestFrontchannelLogoutCacheControl(t *testing.T) {
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableFrontchannelLogout: true,
+ frontchannelLogoutPath: "/frontchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/frontchannel-logout?sid=session123", nil)
+ rw := httptest.NewRecorder()
+
+ oidc.handleFrontchannelLogout(rw, req)
+
+ if rw.Code != http.StatusOK {
+ t.Errorf("Expected status 200, got %d", rw.Code)
+ }
+
+ // Check cache headers
+ cacheControl := rw.Header().Get("Cache-Control")
+ if !strings.Contains(cacheControl, "no-cache") || !strings.Contains(cacheControl, "no-store") {
+ t.Errorf("Expected Cache-Control to contain no-cache and no-store, got %s", cacheControl)
+ }
+
+ pragma := rw.Header().Get("Pragma")
+ if pragma != "no-cache" {
+ t.Errorf("Expected Pragma: no-cache, got %s", pragma)
+ }
+
+ // X-Frame-Options should be removed (to allow iframe embedding)
+ if rw.Header().Get("X-Frame-Options") != "" {
+ t.Error("X-Frame-Options should be removed for front-channel logout")
+ }
+}
+
+// TestConcurrentSessionInvalidation tests concurrent session invalidations
+func TestConcurrentSessionInvalidation(t *testing.T) {
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ sessionInvalidationCache: mockCache,
+ }
+
+ // Invalidate multiple sessions concurrently
+ done := make(chan bool, 10)
+ for i := 0; i < 10; i++ {
+ go func(idx int) {
+ sid := fmt.Sprintf("session-%d", idx)
+ sub := fmt.Sprintf("user%d@example.com", idx)
+ err := oidc.invalidateSession(sid, sub)
+ if err != nil {
+ t.Errorf("Failed to invalidate session %d: %v", idx, err)
+ }
+ done <- true
+ }(i)
+ }
+
+ // Wait for all goroutines
+ for i := 0; i < 10; i++ {
+ <-done
+ }
+
+ // Verify all sessions were invalidated
+ for i := 0; i < 10; i++ {
+ sid := fmt.Sprintf("session-%d", i)
+ sub := fmt.Sprintf("user%d@example.com", i)
+ sidKey := oidc.buildSessionInvalidationKey("sid", sid)
+ subKey := oidc.buildSessionInvalidationKey("sub", sub)
+ if _, found := mockCache.Get(sidKey); !found {
+ t.Errorf("Session %d should have been invalidated by sid", i)
+ }
+ if _, found := mockCache.Get(subKey); !found {
+ t.Errorf("Session %d should have been invalidated by sub", i)
+ }
+ }
+}
+
+// TestSessionInvalidationTimeComparison tests the time comparison logic
+func TestSessionInvalidationTimeComparison(t *testing.T) {
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ sessionInvalidationCache: mockCache,
+ }
+
+ // Create session at specific time
+ sessionCreatedAt := time.Now()
+
+ // Wait a tiny bit and invalidate
+ time.Sleep(10 * time.Millisecond)
+ _ = oidc.invalidateSession("session-time-test", "")
+
+ // Session created before invalidation should be invalidated
+ if !oidc.isSessionInvalidated("session-time-test", "", sessionCreatedAt) {
+ t.Error("Session created before invalidation should be marked as invalidated")
+ }
+
+ // Session created after invalidation (simulated) should NOT be invalidated
+ futureSession := time.Now().Add(1 * time.Second)
+ if oidc.isSessionInvalidated("session-time-test", "", futureSession) {
+ t.Error("Session created after invalidation should NOT be marked as invalidated")
+ }
+}
+
+// TestBackchannelLogoutMissingIat tests that missing iat is rejected
+func TestBackchannelLogoutMissingIat(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token without iat claim
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ // No iat
+ "jti": "unique-id-no-iat",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ "sid": "session-123",
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing iat, got %d", rw.Code)
+ }
+}
+
+// TestBackchannelLogoutMissingSidAndSub tests that missing both sid and sub is rejected
+func TestBackchannelLogoutMissingSidAndSub(t *testing.T) {
+ privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ mockCache := &mockCacheInterface{data: make(map[string]interface{})}
+ x := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.X.Bytes())
+ y := base64.RawURLEncoding.EncodeToString(privateKey.PublicKey.Y.Bytes())
+
+ mockJWKCache := &staticJWKCache{
+ jwks: &JWKSet{
+ Keys: []JWK{{Kty: "EC", Crv: "P-256", X: x, Y: y, Kid: "test-key-1", Use: "sig", Alg: "ES256"}},
+ },
+ }
+
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ enableBackchannelLogout: true,
+ backchannelLogoutPath: "/backchannel-logout",
+ sessionInvalidationCache: mockCache,
+ clientID: "test-client",
+ issuerURL: "https://provider.example.com",
+ jwkCache: mockJWKCache,
+ jwksURL: "https://provider.example.com/.well-known/jwks.json",
+ }
+
+ // Token without sid or sub
+ logoutToken := createSignedLogoutToken(t, privateKey, map[string]interface{}{
+ "iss": "https://provider.example.com",
+ "aud": "test-client",
+ "iat": time.Now().Unix(),
+ "jti": "unique-id-no-sid-sub",
+ "events": map[string]interface{}{
+ "http://schemas.openid.net/event/backchannel-logout": map[string]interface{}{},
+ },
+ // No sid or sub
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/backchannel-logout",
+ strings.NewReader("logout_token="+url.QueryEscape(logoutToken)))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rw := httptest.NewRecorder()
+
+ oidc.handleBackchannelLogout(rw, req)
+
+ if rw.Code != http.StatusBadRequest {
+ t.Errorf("Expected status 400 for missing sid and sub, got %d", rw.Code)
+ }
+}
+
+// createSignedLogoutToken is a helper to create properly signed logout tokens for testing
+func createSignedLogoutToken(t *testing.T, privateKey *ecdsa.PrivateKey, claims map[string]interface{}) string {
+ t.Helper()
+
+ header := map[string]interface{}{
+ "alg": "ES256",
+ "typ": "logout+jwt",
+ "kid": "test-key-1",
+ }
+ headerJSON, _ := json.Marshal(header)
+ headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
+
+ claimsJSON, _ := json.Marshal(claims)
+ claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
+
+ // Sign the token
+ signingInput := headerB64 + "." + claimsB64
+ hash := sha256.Sum256([]byte(signingInput))
+ r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:])
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ // Convert signature to fixed-size format (32 bytes each for P-256)
+ sigBytes := make([]byte, 64)
+ rBytes := r.Bytes()
+ sBytes := s.Bytes()
+ copy(sigBytes[32-len(rBytes):32], rBytes)
+ copy(sigBytes[64-len(sBytes):], sBytes)
+ signatureB64 := base64.RawURLEncoding.EncodeToString(sigBytes)
+
+ return headerB64 + "." + claimsB64 + "." + signatureB64
+}
diff --git a/main.go b/main.go
index 84a63e8..de39983 100644
--- a/main.go
+++ b/main.go
@@ -212,16 +212,21 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return 60 * time.Second
}(),
- tokenCleanupStopChan: make(chan struct{}),
- metadataRefreshStopChan: make(chan struct{}),
- ctx: pluginCtx,
- cancelFunc: cancelFunc,
- suppressDiagnosticLogs: isTestMode(),
- securityHeadersApplier: config.GetSecurityHeadersApplier(),
- scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
- dcrConfig: config.DynamicClientRegistration,
- allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
- minimalHeaders: config.MinimalHeaders,
+ tokenCleanupStopChan: make(chan struct{}),
+ metadataRefreshStopChan: make(chan struct{}),
+ ctx: pluginCtx,
+ cancelFunc: cancelFunc,
+ suppressDiagnosticLogs: isTestMode(),
+ securityHeadersApplier: config.GetSecurityHeadersApplier(),
+ scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
+ dcrConfig: config.DynamicClientRegistration,
+ allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
+ minimalHeaders: config.MinimalHeaders,
+ enableBackchannelLogout: config.EnableBackchannelLogout,
+ enableFrontchannelLogout: config.EnableFrontchannelLogout,
+ backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
+ frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
+ sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
}
// Log audience configuration
diff --git a/middleware.go b/middleware.go
index ebd36e8..5b7f742 100644
--- a/middleware.go
+++ b/middleware.go
@@ -26,6 +26,31 @@ import (
// - rw: The HTTP response writer.
// - req: The incoming HTTP request.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ // Log request entry for debugging routing issues
+ t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path)
+
+ // Handle logout requests early - before waiting for OIDC initialization
+ // This allows users to logout even if the OIDC provider is unavailable
+ if req.URL.Path == t.logoutURLPath {
+ t.logger.Debugf("Logout path matched early: %s", req.URL.Path)
+ t.handleLogout(rw, req)
+ return
+ }
+
+ // Handle backchannel logout (IdP-initiated POST with logout_token)
+ if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath {
+ t.logger.Debug("Backchannel logout path matched")
+ t.handleBackchannelLogout(rw, req)
+ return
+ }
+
+ // Handle front-channel logout (IdP-initiated GET with sid/iss in iframe)
+ if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath {
+ t.logger.Debug("Front-channel logout path matched")
+ t.handleFrontchannelLogout(rw, req)
+ return
+ }
+
if !strings.HasPrefix(req.URL.Path, "/health") {
t.firstRequestMutex.Lock()
if !t.firstRequestReceived {
@@ -42,6 +67,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.firstRequestMutex.Unlock()
}
+ // Check excluded URLs before waiting for initialization
+ if t.determineExcludedURL(req.URL.Path) {
+ t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
+ t.next.ServeHTTP(rw, req)
+ return
+ }
+
+ // Check for SSE requests before waiting for initialization
+ acceptHeader := req.Header.Get("Accept")
+ if strings.Contains(acceptHeader, "text/event-stream") {
+ t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
+ t.next.ServeHTTP(rw, req)
+ return
+ }
+
+ // Log waiting for initialization to help diagnose hanging requests
+ t.logger.Debug("Waiting for OIDC provider initialization...")
+
select {
case <-t.initComplete:
// Read issuerURL with RLock
@@ -83,7 +126,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.next.ServeHTTP(rw, req)
return
}
- acceptHeader := req.Header.Get("Accept")
+ acceptHeader = req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
// Set forwarded user headers from existing session before bypassing
@@ -100,7 +143,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.next.ServeHTTP(rw, req)
return
}
-
t.sessionManager.CleanupOldCookies(rw, req)
session, err := t.sessionManager.GetSession(req)
@@ -131,10 +173,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
- if req.URL.Path == t.logoutURLPath {
- t.handleLogout(rw, req)
- return
- }
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
return
@@ -275,6 +313,24 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
return
}
+ // Check if session has been invalidated via backchannel or front-channel logout
+ if t.enableBackchannelLogout || t.enableFrontchannelLogout {
+ idToken := session.GetIDToken()
+ if idToken != "" {
+ sid, sub, createdAt := t.extractSessionInfo(idToken)
+ if t.isSessionInvalidated(sid, sub, createdAt) {
+ t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
+ // Clear the session and redirect to login
+ if err := session.Clear(req, rw); err != nil {
+ t.logger.Errorf("Error clearing invalidated session: %v", err)
+ }
+ session.ResetRedirectCount()
+ t.defaultInitiateAuthentication(rw, req, session, redirectURL)
+ return
+ }
+ }
+ }
+
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
tokenForClaims = session.GetAccessToken()
diff --git a/middleware_edge_cases_test.go b/middleware_edge_cases_test.go
index e0a265b..8deabf8 100644
--- a/middleware_edge_cases_test.go
+++ b/middleware_edge_cases_test.go
@@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
}
}
+// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable
+// This is critical for allowing users to clear their session when the provider is down
+func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
+ oidc := &TraefikOidc{
+ logger: NewLogger("debug"),
+ initComplete: make(chan struct{}), // Never close to simulate provider unavailable
+ sessionManager: createTestSessionManager(t),
+ firstRequestReceived: true,
+ metadataRefreshStarted: true,
+ logoutURLPath: "/logout",
+ postLogoutRedirectURI: "/",
+ forceHTTPS: false,
+ }
+ // Note: initComplete is NOT closed, simulating OIDC provider being unavailable
+
+ req := httptest.NewRequest("GET", "/logout", nil)
+ req.Host = "example.com"
+ rw := httptest.NewRecorder()
+
+ oidc.ServeHTTP(rw, req)
+
+ // Should redirect to post-logout URI even without OIDC initialization
+ if rw.Code != http.StatusFound {
+ t.Errorf("Expected redirect (302) for logout, got %d", rw.Code)
+ }
+
+ location := rw.Header().Get("Location")
+ if location == "" {
+ t.Error("Expected Location header for logout redirect")
+ }
+}
+
// TestMiddlewareDomainRestrictions tests domain-based access control
// NOTE: Currently commented out due to complex session setup requirements
// These scenarios are tested indirectly through integration tests
diff --git a/settings.go b/settings.go
index c17a664..ea5690e 100644
--- a/settings.go
+++ b/settings.go
@@ -65,6 +65,10 @@ type Config struct {
ForceHTTPS bool `json:"forceHTTPS"`
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
+ EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
+ EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
+ BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
+ FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
}
// RedisConfig configures Redis cache backend settings for distributed caching.
@@ -744,15 +748,6 @@ func newNoOpLogger() *Logger {
// - code: The HTTP status code for the response.
// - logger: The Logger instance to use for logging the error.
//
-// handleError writes an HTTP error response with the specified status code and message.
-// It logs the error and sets appropriate headers before writing the response.
-//
-//lint:ignore U1000 Kept for potential future error handling
-func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
- logger.Error("%s", message)
- http.Error(w, message, code)
-}
-
// GetSecurityHeadersApplier returns a function that applies security headers
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
@@ -1058,111 +1053,6 @@ func (rc *RedisConfig) ApplyEnvFallbacks() {
}
}
-// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
-// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
-// This function is kept for backward compatibility but should not be used directly.
-func LoadRedisConfigFromEnv() *RedisConfig {
- // Check if Redis is enabled
- enabledStr := os.Getenv("REDIS_ENABLED")
- if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
- return nil
- }
-
- config := &RedisConfig{
- Enabled: true,
- }
-
- // Parse numeric values
- if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
- if db, err := strconv.Atoi(dbStr); err == nil {
- config.DB = db
- }
- }
-
- if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
- if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
- config.PoolSize = poolSize
- }
- }
-
- if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
- if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
- config.ConnectTimeout = timeout
- }
- }
-
- if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
- if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
- config.ReadTimeout = timeout
- }
- }
-
- if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
- if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
- config.WriteTimeout = timeout
- }
- }
-
- // Parse boolean values
- if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
- config.EnableTLS = true
- }
-
- if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
- config.TLSSkipVerify = true
- }
-
- // Parse hybrid mode settings
- if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
- if size, err := strconv.Atoi(l1SizeStr); err == nil {
- config.HybridL1Size = size
- }
- }
-
- if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
- if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
- config.HybridL1MemoryMB = memory
- }
- }
-
- // Parse circuit breaker settings
- if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
- config.EnableCircuitBreaker = false
- } else {
- config.EnableCircuitBreaker = true // Default to enabled
- }
-
- if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
- if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
- config.CircuitBreakerThreshold = threshold
- }
- }
-
- if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
- if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
- config.CircuitBreakerTimeout = timeout
- }
- }
-
- // Parse health check settings
- if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
- config.EnableHealthCheck = false
- } else {
- config.EnableHealthCheck = true // Default to enabled
- }
-
- if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
- if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
- config.HealthCheckInterval = interval
- }
- }
-
- // Apply defaults after loading from env
- config.ApplyDefaults()
-
- return config
-}
-
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if origin == allowed || allowed == "*" {
diff --git a/types.go b/types.go
index 46ebe9d..759470a 100644
--- a/types.go
+++ b/types.go
@@ -119,6 +119,8 @@ type TraefikOidc struct {
clientID string
clientSecret string
registrationURL string
+ backchannelLogoutPath string
+ frontchannelLogoutPath string
scopesSupported []string
scopes []string
refreshGracePeriod time.Duration
@@ -126,7 +128,10 @@ type TraefikOidc struct {
shutdownOnce sync.Once
metadataRetryMutex sync.Mutex
firstRequestMutex sync.Mutex
+ sessionInvalidationCache CacheInterface
minimalHeaders bool
+ enableBackchannelLogout bool
+ enableFrontchannelLogout bool
firstRequestReceived bool
requireTokenIntrospection bool
metadataRefreshStarted bool
diff --git a/universal_cache.go b/universal_cache.go
index 3cb4dc1..3207fd0 100644
--- a/universal_cache.go
+++ b/universal_cache.go
@@ -720,22 +720,6 @@ func (c *UniversalCache) SetWithMetadata(key string, value interface{}, ttl time
return nil
}
-// GetTyped retrieves a typed value from the cache
-func GetTyped[T any](c *UniversalCache, key string) (T, bool) {
- var zero T
- value, exists := c.Get(key)
- if !exists {
- return zero, false
- }
-
- typed, ok := value.(T)
- if !ok {
- return zero, false
- }
-
- return typed, true
-}
-
// TokenCacheOperations provides token-specific operations
func (c *UniversalCache) BlacklistToken(token string, ttl time.Duration) error {
if c.config.Type != CacheTypeToken {
diff --git a/universal_cache_singleton.go b/universal_cache_singleton.go
index 9f07e21..3ccdde9 100644
--- a/universal_cache_singleton.go
+++ b/universal_cache_singleton.go
@@ -13,21 +13,22 @@ import (
// It runs a single consolidated cleanup goroutine for all caches, reducing
// goroutine count and CPU overhead compared to per-cache cleanup routines.
type UniversalCacheManager struct {
- sharedBackend backends.CacheBackend
- ctx context.Context
- tokenTypeCache *UniversalCache
- jwkCache *UniversalCache
- sessionCache *UniversalCache
- introspectionCache *UniversalCache
- tokenCache *UniversalCache
- metadataCache *UniversalCache
- dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
- logger *Logger
- blacklistCache *UniversalCache
- cancel context.CancelFunc
- wg sync.WaitGroup
- mu sync.RWMutex
- cleanupStarted bool
+ sharedBackend backends.CacheBackend
+ ctx context.Context
+ tokenTypeCache *UniversalCache
+ jwkCache *UniversalCache
+ sessionCache *UniversalCache
+ introspectionCache *UniversalCache
+ tokenCache *UniversalCache
+ metadataCache *UniversalCache
+ dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
+ sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
+ logger *Logger
+ blacklistCache *UniversalCache
+ cancel context.CancelFunc
+ wg sync.WaitGroup
+ mu sync.RWMutex
+ cleanupStarted bool
}
var (
@@ -170,6 +171,16 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
+
+ // Initialize session invalidation cache for backchannel/front-channel logout
+ // This cache stores invalidated session IDs and subjects to revoke sessions
+ manager.sessionInvalidationCache = NewUniversalCache(UniversalCacheConfig{
+ Type: CacheTypeSession,
+ MaxSize: 5000, // Support many concurrent invalidations
+ DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
+ Logger: logger,
+ SkipAutoCleanup: true, // Managed cleanup
+ })
}
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
@@ -363,6 +374,19 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
createBackend("dcr"),
)
+ // Session invalidation cache - CRITICAL for distributed backchannel/front-channel logout
+ // Uses Redis backend to share session invalidations across all Traefik replicas
+ manager.sessionInvalidationCache = NewUniversalCacheWithBackend(
+ UniversalCacheConfig{
+ Type: CacheTypeSession,
+ MaxSize: 5000, // Support many concurrent invalidations
+ DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
+ Logger: logger,
+ SkipAutoCleanup: true, // Managed cleanup
+ },
+ createBackend("session_invalidation"),
+ )
+
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
}
@@ -411,6 +435,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
m.introspectionCache,
m.tokenTypeCache,
m.dcrCredentialsCache,
+ m.sessionInvalidationCache,
}
m.mu.RUnlock()
@@ -452,13 +477,6 @@ func (m *UniversalCacheManager) GetJWKCache() *UniversalCache {
return m.jwkCache
}
-// GetSessionCache returns the session cache
-func (m *UniversalCacheManager) GetSessionCache() *UniversalCache {
- m.mu.RLock()
- defer m.mu.RUnlock()
- return m.sessionCache
-}
-
// GetIntrospectionCache returns the token introspection cache
func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache {
m.mu.RLock()
@@ -473,6 +491,13 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
return m.tokenTypeCache
}
+// GetSessionInvalidationCache returns the session invalidation cache for backchannel/front-channel logout
+func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return m.sessionInvalidationCache
+}
+
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
m.mu.RLock()
@@ -495,7 +520,7 @@ func (m *UniversalCacheManager) Close() error {
// Close all caches first (they won't close the shared backend)
for _, cache := range []*UniversalCache{
- m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache,
+ m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
} {
if cache != nil {
_ = cache.Close() // Safe to ignore: best effort cache cleanup
@@ -516,35 +541,6 @@ func (m *UniversalCacheManager) Close() error {
return nil
}
-// InitializeCacheManagerFromConfig initializes the cache manager with configuration
-// This should be called early in the application startup with the loaded configuration
-func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
- logger := NewLogger(config.LogLevel)
-
- // Initialize Redis config if not present
- if config.Redis == nil {
- config.Redis = &RedisConfig{}
- }
-
- // Apply environment variable fallbacks for fields not set in config
- // This allows env vars to be used as optional overrides only when
- // the config field is not explicitly set through Traefik
- config.Redis.ApplyEnvFallbacks()
-
- // Apply defaults after env fallbacks
- config.Redis.ApplyDefaults()
-
- // Log cache backend selection
- if config.Redis != nil && config.Redis.Enabled {
- logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
- config.Redis.CacheMode, config.Redis.Address)
- } else {
- logger.Info("Initializing cache backend with memory-only mode")
- }
-
- return GetUniversalCacheManagerWithConfig(logger, config.Redis)
-}
-
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
// This should only be called in test code to ensure proper cleanup between tests
func ResetUniversalCacheManagerForTesting() {