mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 57724918fe | |||
| 775de2ada1 | |||
| 7816e05c98 | |||
| 8bf7998150 | |||
| 22c4323fcb | |||
| 06b219d1f8 |
@@ -1,3 +1,4 @@
|
||||
docker/
|
||||
.claude/*.out
|
||||
*.test
|
||||
.leann/
|
||||
|
||||
+304
@@ -1021,6 +1021,79 @@ configuration:
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
enableBackchannelLogout:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST).
|
||||
|
||||
When enabled, the middleware accepts logout tokens at the configured backchannelLogoutURL.
|
||||
The IdP sends a signed JWT (logout_token) to notify the application that a user's session
|
||||
should be terminated.
|
||||
|
||||
This implements the OIDC Back-Channel Logout 1.0 specification.
|
||||
See: https://openid.net/specs/openid-connect-backchannel-1_0.html
|
||||
|
||||
Requirements:
|
||||
- backchannelLogoutURL must be configured
|
||||
- The IdP must be configured to send logout tokens to your backchannel URL
|
||||
- Logout tokens are validated using the IdP's JWKS
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
backchannelLogoutURL:
|
||||
type: string
|
||||
description: |
|
||||
Path for receiving backchannel logout tokens from the IdP.
|
||||
|
||||
This endpoint receives POST requests with a logout_token JWT in the request body.
|
||||
The token is validated against the IdP's JWKS and contains the session ID (sid)
|
||||
and/or subject (sub) to invalidate.
|
||||
|
||||
Example: /backchannel-logout
|
||||
|
||||
The full URL to configure in your IdP would be:
|
||||
https://your-app.example.com/backchannel-logout
|
||||
|
||||
Note: This path should be unique and not conflict with your application routes.
|
||||
required: false
|
||||
|
||||
enableFrontchannelLogout:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe).
|
||||
|
||||
When enabled, the middleware accepts logout requests at the configured frontchannelLogoutURL.
|
||||
The IdP embeds an iframe pointing to this URL when the user logs out, allowing the
|
||||
application to clear the user's session.
|
||||
|
||||
This implements the OIDC Front-Channel Logout 1.0 specification.
|
||||
See: https://openid.net/specs/openid-connect-frontchannel-1_0.html
|
||||
|
||||
Requirements:
|
||||
- frontchannelLogoutURL must be configured
|
||||
- The IdP must be configured with your front-channel logout URL
|
||||
- Your CSP headers must allow being embedded in an iframe from the IdP
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
frontchannelLogoutURL:
|
||||
type: string
|
||||
description: |
|
||||
Path for receiving front-channel logout requests from the IdP.
|
||||
|
||||
This endpoint receives GET requests with optional sid (session ID) and iss (issuer)
|
||||
query parameters. When called, it invalidates the user's session.
|
||||
|
||||
Example: /frontchannel-logout
|
||||
|
||||
The full URL to configure in your IdP would be:
|
||||
https://your-app.example.com/frontchannel-logout
|
||||
|
||||
Note: This path should be unique and not conflict with your application routes.
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
@@ -1630,3 +1703,234 @@ configuration:
|
||||
|
||||
Default: 30 seconds
|
||||
required: false
|
||||
|
||||
dynamicClientRegistration:
|
||||
type: object
|
||||
description: |
|
||||
Configuration for OIDC Dynamic Client Registration (RFC 7591/7592).
|
||||
|
||||
Dynamic Client Registration allows the middleware to automatically register
|
||||
itself as an OAuth 2.0 client with the OIDC provider, eliminating the need
|
||||
to manually create and manage client credentials.
|
||||
|
||||
This is particularly useful for:
|
||||
- Automated deployments where manual client creation is impractical
|
||||
- Multi-tenant scenarios requiring per-deployment client isolation
|
||||
- Development and testing environments
|
||||
- Kubernetes environments with multiple replicas
|
||||
|
||||
For multi-replica deployments (Kubernetes), enable Redis storage to share
|
||||
credentials across all instances and prevent registration race conditions.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: "redis" # Use Redis for distributed storage
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- https://app.example.com/oauth2/callback
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
```
|
||||
required: false
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable dynamic client registration with the OIDC provider.
|
||||
When enabled and clientID is not set, the middleware will automatically
|
||||
register itself with the provider using the configuration in clientMetadata.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
persistCredentials:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable persistence of client credentials after registration.
|
||||
When enabled, credentials are saved to the configured storage backend
|
||||
and reloaded on restart to avoid re-registration.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
storageBackend:
|
||||
type: string
|
||||
description: |
|
||||
Storage backend for persisting DCR credentials.
|
||||
|
||||
Options:
|
||||
- "file": Store credentials in a local file (default for backward compatibility)
|
||||
- "redis": Store credentials in Redis (recommended for multi-replica deployments)
|
||||
- "auto": Use Redis if available, fall back to file storage
|
||||
|
||||
For Kubernetes deployments with multiple replicas, use "redis" to ensure
|
||||
all instances share the same client credentials and prevent registration
|
||||
race conditions where each replica registers its own client.
|
||||
|
||||
Default: "auto"
|
||||
required: false
|
||||
enum:
|
||||
- file
|
||||
- redis
|
||||
- auto
|
||||
|
||||
credentialsFile:
|
||||
type: string
|
||||
description: |
|
||||
Path to store client credentials when using file-based storage.
|
||||
The file will be created with restrictive permissions (0600).
|
||||
|
||||
Default: "/tmp/oidc-client-credentials.json"
|
||||
required: false
|
||||
|
||||
redisKeyPrefix:
|
||||
type: string
|
||||
description: |
|
||||
Prefix for Redis keys when using Redis storage.
|
||||
Useful for isolating credentials between different applications
|
||||
or environments sharing the same Redis instance.
|
||||
|
||||
Default: "dcr:creds:"
|
||||
required: false
|
||||
|
||||
registrationEndpoint:
|
||||
type: string
|
||||
description: |
|
||||
Override the registration endpoint URL.
|
||||
If not specified, the endpoint will be discovered from provider metadata.
|
||||
|
||||
Some providers may not advertise their registration endpoint in metadata,
|
||||
in which case you need to specify it explicitly.
|
||||
|
||||
Example: "https://auth.example.com/oauth/register"
|
||||
required: false
|
||||
|
||||
initialAccessToken:
|
||||
type: string
|
||||
description: |
|
||||
Initial Access Token for protected registration endpoints.
|
||||
Some providers require an access token to authorize client registration.
|
||||
|
||||
If your provider requires authentication for registration, obtain an
|
||||
initial access token from the provider and configure it here.
|
||||
|
||||
For Kubernetes, you can use secret references:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
required: false
|
||||
|
||||
clientMetadata:
|
||||
type: object
|
||||
description: |
|
||||
Client metadata to include in the registration request (RFC 7591).
|
||||
This defines the properties of the OAuth 2.0 client to be registered.
|
||||
required: false
|
||||
properties:
|
||||
redirect_uris:
|
||||
type: array
|
||||
description: |
|
||||
Array of redirect URIs for the client. Required for registration.
|
||||
These must match the callback URLs that will be used in authentication flows.
|
||||
|
||||
Example: ["https://app.example.com/oauth2/callback"]
|
||||
required: true
|
||||
items:
|
||||
type: string
|
||||
|
||||
client_name:
|
||||
type: string
|
||||
description: |
|
||||
Human-readable name of the client.
|
||||
This is typically displayed in consent screens.
|
||||
|
||||
Example: "My Application"
|
||||
required: false
|
||||
|
||||
application_type:
|
||||
type: string
|
||||
description: |
|
||||
Type of application. Affects security defaults.
|
||||
|
||||
Options:
|
||||
- "web": Server-side web application (default)
|
||||
- "native": Native/mobile application
|
||||
|
||||
Default: "web"
|
||||
required: false
|
||||
|
||||
grant_types:
|
||||
type: array
|
||||
description: |
|
||||
OAuth 2.0 grant types the client will use.
|
||||
|
||||
Default: ["authorization_code", "refresh_token"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
response_types:
|
||||
type: array
|
||||
description: |
|
||||
OAuth 2.0 response types the client will use.
|
||||
|
||||
Default: ["code"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
description: |
|
||||
Authentication method for the token endpoint.
|
||||
|
||||
Options:
|
||||
- "client_secret_basic": HTTP Basic authentication (default)
|
||||
- "client_secret_post": Client credentials in POST body
|
||||
- "none": Public client (no authentication)
|
||||
|
||||
Default: "client_secret_basic"
|
||||
required: false
|
||||
|
||||
scope:
|
||||
type: string
|
||||
description: |
|
||||
Space-separated list of scopes the client is authorized to request.
|
||||
|
||||
Example: "openid profile email"
|
||||
required: false
|
||||
|
||||
contacts:
|
||||
type: array
|
||||
description: |
|
||||
Array of contact email addresses for the client administrator.
|
||||
|
||||
Example: ["admin@example.com"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
logo_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's logo image for consent screens.
|
||||
required: false
|
||||
|
||||
client_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's home page.
|
||||
required: false
|
||||
|
||||
policy_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's privacy policy.
|
||||
required: false
|
||||
|
||||
tos_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's terms of service.
|
||||
required: false
|
||||
|
||||
@@ -8,7 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
|
||||
|
||||
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
|
||||
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
|
||||
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration
|
||||
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration, with Redis storage support for multi-replica deployments
|
||||
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
|
||||
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
|
||||
- **Domain restrictions**: Limit access to specific email domains or individual users
|
||||
@@ -154,6 +154,10 @@ The middleware supports the following configuration options:
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `enableBackchannelLogout` | Enable OIDC Back-Channel Logout (IdP-initiated logout via server-to-server POST) | `false` | `true` |
|
||||
| `backchannelLogoutURL` | The path for receiving backchannel logout tokens from the IdP | none | `/backchannel-logout` |
|
||||
| `enableFrontchannelLogout` | Enable OIDC Front-Channel Logout (IdP-initiated logout via iframe) | `false` | `true` |
|
||||
| `frontchannelLogoutURL` | The path for receiving front-channel logout requests from the IdP | none | `/frontchannel-logout` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
@@ -1148,6 +1152,50 @@ spec:
|
||||
- roles # Appended to defaults: ["openid", "profile", "email", "roles"]
|
||||
```
|
||||
|
||||
### With IdP-Initiated Logout (Backchannel & Front-Channel)
|
||||
|
||||
This plugin supports [OIDC Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) and [OIDC Front-Channel Logout](https://openid.net/specs/openid-connect-frontchannel-1_0.html) for IdP-initiated single logout.
|
||||
|
||||
**Backchannel Logout** (recommended): The IdP sends a server-to-server POST request with a signed `logout_token` JWT when a user logs out.
|
||||
|
||||
**Front-Channel Logout**: The IdP loads an iframe with the logout URL to invalidate the session in the browser.
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-idp-logout
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout # RP-initiated logout
|
||||
|
||||
# Backchannel Logout (server-to-server)
|
||||
enableBackchannelLogout: true
|
||||
backchannelLogoutURL: /backchannel-logout
|
||||
|
||||
# Front-Channel Logout (iframe-based)
|
||||
enableFrontchannelLogout: true
|
||||
frontchannelLogoutURL: /frontchannel-logout
|
||||
|
||||
# For multi-replica deployments, use Redis to share session invalidations
|
||||
redis:
|
||||
enabled: true
|
||||
address: redis:6379
|
||||
```
|
||||
|
||||
> **Note**: For multi-replica deployments, you **must** enable Redis to share session invalidation state across all instances. Otherwise, a logout on one instance won't invalidate sessions on other instances.
|
||||
|
||||
**IdP Configuration**: Configure your IdP to send logout requests to:
|
||||
- **Backchannel**: `https://your-app.example.com/backchannel-logout` (POST with `logout_token`)
|
||||
- **Front-Channel**: `https://your-app.example.com/frontchannel-logout?sid=SESSION_ID&iss=ISSUER` (GET in iframe)
|
||||
|
||||
### With Templated Headers
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
# Security Fix: Integer Overflow Protection in Cache Serialization
|
||||
|
||||
## Summary
|
||||
|
||||
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
|
||||
|
||||
## Vulnerability
|
||||
|
||||
**Locations**: `universal_cache.go` lines 789 and 811
|
||||
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
|
||||
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
|
||||
|
||||
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
|
||||
|
||||
## Fix Applied
|
||||
|
||||
1. **Added size limit constant**:
|
||||
```go
|
||||
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
|
||||
```
|
||||
|
||||
2. **Size validation before allocation**:
|
||||
- Validates entry size doesn't exceed limit
|
||||
- Validates adding marker byte won't overflow
|
||||
- Returns descriptive error messages
|
||||
|
||||
3. **Comprehensive test coverage**:
|
||||
- Oversized byte slices (>64 MiB)
|
||||
- Exact max size edge case
|
||||
- Safe sizes (normal operation)
|
||||
- Large JSON data structures
|
||||
|
||||
## Verification
|
||||
|
||||
✅ All tests pass with race detection
|
||||
✅ No security issues (golangci-lint, gosec)
|
||||
✅ 76.3% test coverage maintained
|
||||
|
||||
## Impact
|
||||
|
||||
- No breaking changes
|
||||
- Negligible performance overhead
|
||||
- Prevents potential buffer overflows
|
||||
- Predictable memory usage
|
||||
|
||||
---
|
||||
|
||||
**Date**: January 8, 2026
|
||||
**Severity**: High → Resolved
|
||||
@@ -104,6 +104,14 @@ func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
|
||||
// for backchannel and front-channel logout (IdP-initiated logout)
|
||||
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
|
||||
@@ -0,0 +1,290 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
|
||||
)
|
||||
|
||||
// DCRStorageBackend represents the type of storage backend for DCR credentials.
|
||||
// Alias for internal package type for backward compatibility.
|
||||
type DCRStorageBackend = dcrstorage.StorageBackend
|
||||
|
||||
const (
|
||||
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
|
||||
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
|
||||
|
||||
// DCRStorageBackendRedis uses Redis for distributed storage
|
||||
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
|
||||
|
||||
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
|
||||
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
|
||||
)
|
||||
|
||||
// DCRCredentialsStore defines the interface for storing DCR credentials.
|
||||
// This abstraction allows different storage backends (file, Redis) to be used
|
||||
// for persisting OIDC Dynamic Client Registration credentials across nodes.
|
||||
type DCRCredentialsStore interface {
|
||||
// Save stores the client registration response for a provider
|
||||
// The providerURL is used as a key to support multi-tenant scenarios
|
||||
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
|
||||
|
||||
// Load retrieves stored credentials for a provider
|
||||
// Returns nil, nil if no credentials exist (not an error)
|
||||
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
|
||||
|
||||
// Delete removes stored credentials for a provider
|
||||
Delete(ctx context.Context, providerURL string) error
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
Exists(ctx context.Context, providerURL string) (bool, error)
|
||||
}
|
||||
|
||||
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
|
||||
type loggerAdapter struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
|
||||
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
|
||||
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
|
||||
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
|
||||
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
|
||||
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
|
||||
|
||||
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
|
||||
type cacheAdapter struct {
|
||||
cache *UniversalCache
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Get(key string) (any, bool) {
|
||||
return c.cache.Get(key)
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
|
||||
return c.cache.Set(key, value, ttl)
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Delete(key string) {
|
||||
c.cache.Delete(key)
|
||||
}
|
||||
|
||||
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
|
||||
type fileStoreWrapper struct {
|
||||
inner *dcrstorage.FileStore
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
innerCreds := convertCredsToInternal(creds)
|
||||
return w.inner.Save(ctx, providerURL, innerCreds)
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
innerCreds, err := w.inner.Load(ctx, providerURL)
|
||||
if err != nil || innerCreds == nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertCredsFromInternal(innerCreds), nil
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
|
||||
return w.inner.Delete(ctx, providerURL)
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
return w.inner.Exists(ctx, providerURL)
|
||||
}
|
||||
|
||||
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
|
||||
func (w *fileStoreWrapper) basePath() string {
|
||||
return w.inner.BasePath()
|
||||
}
|
||||
|
||||
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
|
||||
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
|
||||
return w.inner.GetFilePath(providerURL)
|
||||
}
|
||||
|
||||
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
|
||||
type redisStoreWrapper struct {
|
||||
inner *dcrstorage.RedisStore
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
innerCreds := convertCredsToInternal(creds)
|
||||
return w.inner.Save(ctx, providerURL, innerCreds)
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
innerCreds, err := w.inner.Load(ctx, providerURL)
|
||||
if err != nil || innerCreds == nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertCredsFromInternal(innerCreds), nil
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
|
||||
return w.inner.Delete(ctx, providerURL)
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
return w.inner.Exists(ctx, providerURL)
|
||||
}
|
||||
|
||||
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
|
||||
// This is the default storage backend for backward compatibility with existing deployments.
|
||||
type FileCredentialsStore = fileStoreWrapper
|
||||
|
||||
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
|
||||
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
|
||||
type RedisCredentialsStore = redisStoreWrapper
|
||||
|
||||
// NewFileCredentialsStore creates a new file-based credentials store.
|
||||
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
|
||||
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
|
||||
var dcrLogger dcrstorage.Logger
|
||||
if logger != nil {
|
||||
dcrLogger = &loggerAdapter{logger: logger}
|
||||
}
|
||||
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
|
||||
return &fileStoreWrapper{inner: inner}
|
||||
}
|
||||
|
||||
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
|
||||
// The cache should be configured with a Redis backend for distributed storage.
|
||||
// If keyPrefix is empty, defaults to "dcr:creds:"
|
||||
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
|
||||
var dcrLogger dcrstorage.Logger
|
||||
if logger != nil {
|
||||
dcrLogger = &loggerAdapter{logger: logger}
|
||||
}
|
||||
cacheAdapt := &cacheAdapter{cache: cache}
|
||||
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
|
||||
return &redisStoreWrapper{inner: inner}
|
||||
}
|
||||
|
||||
// Helper functions to convert between main package and internal package types
|
||||
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
|
||||
if creds == nil {
|
||||
return nil
|
||||
}
|
||||
return &dcrstorage.ClientRegistrationResponse{
|
||||
SubjectType: creds.SubjectType,
|
||||
LogoURI: creds.LogoURI,
|
||||
RegistrationAccessToken: creds.RegistrationAccessToken,
|
||||
RegistrationClientURI: creds.RegistrationClientURI,
|
||||
Scope: creds.Scope,
|
||||
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
|
||||
TOSURI: creds.TOSURI,
|
||||
PolicyURI: creds.PolicyURI,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
ApplicationType: creds.ApplicationType,
|
||||
ClientID: creds.ClientID,
|
||||
ClientName: creds.ClientName,
|
||||
JWKSURI: creds.JWKSURI,
|
||||
ClientURI: creds.ClientURI,
|
||||
Contacts: creds.Contacts,
|
||||
GrantTypes: creds.GrantTypes,
|
||||
ResponseTypes: creds.ResponseTypes,
|
||||
RedirectURIs: creds.RedirectURIs,
|
||||
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
|
||||
ClientIDIssuedAt: creds.ClientIDIssuedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
|
||||
if creds == nil {
|
||||
return nil
|
||||
}
|
||||
return &ClientRegistrationResponse{
|
||||
SubjectType: creds.SubjectType,
|
||||
LogoURI: creds.LogoURI,
|
||||
RegistrationAccessToken: creds.RegistrationAccessToken,
|
||||
RegistrationClientURI: creds.RegistrationClientURI,
|
||||
Scope: creds.Scope,
|
||||
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
|
||||
TOSURI: creds.TOSURI,
|
||||
PolicyURI: creds.PolicyURI,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
ApplicationType: creds.ApplicationType,
|
||||
ClientID: creds.ClientID,
|
||||
ClientName: creds.ClientName,
|
||||
JWKSURI: creds.JWKSURI,
|
||||
ClientURI: creds.ClientURI,
|
||||
Contacts: creds.Contacts,
|
||||
GrantTypes: creds.GrantTypes,
|
||||
ResponseTypes: creds.ResponseTypes,
|
||||
RedirectURIs: creds.RedirectURIs,
|
||||
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
|
||||
ClientIDIssuedAt: creds.ClientIDIssuedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
|
||||
// This factory function handles backend selection logic:
|
||||
// - "file": Use file-based storage (default for backward compatibility)
|
||||
// - "redis": Use Redis exclusively (fails if Redis unavailable)
|
||||
// - "auto": Use Redis if available, fallback to file
|
||||
func NewDCRCredentialsStore(
|
||||
config *DynamicClientRegistrationConfig,
|
||||
cacheManager *CacheManager,
|
||||
logger *Logger,
|
||||
) (DCRCredentialsStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("DCR config is nil")
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
backend := config.StorageBackend
|
||||
if backend == "" {
|
||||
backend = string(DCRStorageBackendAuto) // Default to auto selection
|
||||
}
|
||||
|
||||
switch DCRStorageBackend(backend) {
|
||||
case DCRStorageBackendFile:
|
||||
logger.Info("Using file-based storage for DCR credentials")
|
||||
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
|
||||
|
||||
case DCRStorageBackendRedis:
|
||||
cache := getDCRCache(cacheManager)
|
||||
if cache == nil {
|
||||
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
|
||||
}
|
||||
logger.Info("Using Redis storage for DCR credentials")
|
||||
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
|
||||
|
||||
case DCRStorageBackendAuto:
|
||||
// Try Redis first, fallback to file
|
||||
cache := getDCRCache(cacheManager)
|
||||
if cache != nil && cache.backend != nil {
|
||||
logger.Info("Auto-selected Redis storage for DCR credentials")
|
||||
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
|
||||
}
|
||||
logger.Info("Redis not available, using file storage for DCR credentials")
|
||||
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
|
||||
}
|
||||
}
|
||||
|
||||
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
|
||||
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
|
||||
if cacheManager == nil {
|
||||
return nil
|
||||
}
|
||||
cacheManager.mu.RLock()
|
||||
defer cacheManager.mu.RUnlock()
|
||||
|
||||
if cacheManager.manager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cacheManager.manager.GetDCRCredentialsCache()
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
|
||||
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temp directory for test files
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "test-access-token",
|
||||
RegistrationClientURI: "https://example.com/register/test-client-id",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
ResponseTypes: []string{"code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
// Save credentials
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
// Load credentials
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
|
||||
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
tempDir2 := t.TempDir()
|
||||
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
|
||||
|
||||
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent file: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
|
||||
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Expected credentials to not exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existent credentials", func(t *testing.T) {
|
||||
// Should not error
|
||||
err := store.Delete(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_MultiProvider tests multi-provider support
|
||||
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
provider1 := "https://auth1.example.com"
|
||||
provider2 := "https://auth2.example.com"
|
||||
|
||||
creds1 := &ClientRegistrationResponse{
|
||||
ClientID: "client-1",
|
||||
ClientSecret: "secret-1",
|
||||
}
|
||||
creds2 := &ClientRegistrationResponse{
|
||||
ClientID: "client-2",
|
||||
ClientSecret: "secret-2",
|
||||
}
|
||||
|
||||
// Save credentials for both providers
|
||||
if err := store.Save(ctx, provider1, creds1); err != nil {
|
||||
t.Fatalf("Failed to save creds1: %v", err)
|
||||
}
|
||||
if err := store.Save(ctx, provider2, creds2); err != nil {
|
||||
t.Fatalf("Failed to save creds2: %v", err)
|
||||
}
|
||||
|
||||
// Load and verify each provider's credentials
|
||||
loaded1, err := store.Load(ctx, provider1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds1: %v", err)
|
||||
}
|
||||
if loaded1.ClientID != "client-1" {
|
||||
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
|
||||
}
|
||||
|
||||
loaded2, err := store.Load(ctx, provider2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds2: %v", err)
|
||||
}
|
||||
if loaded2.ClientID != "client-2" {
|
||||
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
|
||||
}
|
||||
|
||||
// Delete one shouldn't affect the other
|
||||
if err := store.Delete(ctx, provider1); err != nil {
|
||||
t.Fatalf("Failed to delete creds1: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, provider2)
|
||||
if !exists {
|
||||
t.Error("Provider 2 credentials should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
|
||||
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
|
||||
// Concurrent saves
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = store.Save(ctx, providerURL, creds)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Concurrent loads
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = store.Load(ctx, providerURL)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Final verification
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after concurrent access: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test-client" {
|
||||
t.Error("Credentials corrupted after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_InvalidInput tests error handling
|
||||
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty provider URL uses default path", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
err := store.Save(ctx, "", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with empty provider URL failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with empty provider URL failed: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials with empty provider URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_DefaultPath tests default path behavior
|
||||
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore("", logger)
|
||||
|
||||
// Just verify we can create with empty path and it has a default
|
||||
if store.basePath() == "" {
|
||||
t.Error("Expected default base path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
|
||||
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create an in-memory cache for testing
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "redis-test-client",
|
||||
ClientSecret: "redis-test-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "redis-test-token",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
}
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
|
||||
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("expired credentials should fail", func(t *testing.T) {
|
||||
expiredCreds := &ClientRegistrationResponse{
|
||||
ClientID: "expired-client",
|
||||
ClientSecret: "expired-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "no-expiry-client",
|
||||
ClientSecret: "no-expiry-secret",
|
||||
ClientSecretExpiresAt: 0, // No expiry
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://noexpiry.example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials without expiry: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_InvalidInput tests error handling
|
||||
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDCRStorageFactory tests the factory function
|
||||
func TestDCRStorageFactory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("nil config returns error", func(t *testing.T) {
|
||||
_, err := NewDCRCredentialsStore(nil, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("file backend creates file store", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "file",
|
||||
CredentialsFile: "/tmp/test-creds.json",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file store: %v", err)
|
||||
}
|
||||
if store == nil {
|
||||
t.Error("Expected store but got nil")
|
||||
}
|
||||
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "redis",
|
||||
}
|
||||
|
||||
_, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for redis backend without cache manager")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "auto",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auto store: %v", err)
|
||||
}
|
||||
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore for auto without redis")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown backend returns error", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "unknown",
|
||||
}
|
||||
|
||||
_, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for unknown backend")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty backend defaults to auto", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create store with empty backend: %v", err)
|
||||
}
|
||||
|
||||
// Should default to file (auto without redis)
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore for empty backend")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDynamicClientRegistrar_WithStore tests registrar with store
|
||||
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrarWithStore(
|
||||
nil, // httpClient
|
||||
logger,
|
||||
config,
|
||||
"https://auth.example.com",
|
||||
store,
|
||||
)
|
||||
|
||||
if registrar == nil {
|
||||
t.Fatal("Expected registrar but got nil")
|
||||
}
|
||||
|
||||
if registrar.store == nil {
|
||||
t.Error("Expected store to be set")
|
||||
}
|
||||
|
||||
// Test SetStore
|
||||
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
|
||||
registrar.SetStore(newStore)
|
||||
|
||||
if registrar.store != newStore {
|
||||
t.Error("SetStore did not update the store")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
|
||||
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
providerURL := "https://auth.example.com"
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-save credentials
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "pre-saved-client",
|
||||
ClientSecret: "pre-saved-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
}
|
||||
if err := store.Save(ctx, providerURL, testCreds); err != nil {
|
||||
t.Fatalf("Failed to pre-save credentials: %v", err)
|
||||
}
|
||||
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrarWithStore(
|
||||
nil,
|
||||
logger,
|
||||
config,
|
||||
providerURL,
|
||||
store,
|
||||
)
|
||||
|
||||
// Test loading via the internal method
|
||||
loaded, err := registrar.loadCredentialsFromStore(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load from store: %v", err)
|
||||
}
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != "pre-saved-client" {
|
||||
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
|
||||
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
// Write corrupted JSON
|
||||
filePath := store.getFilePath(providerURL)
|
||||
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
|
||||
t.Fatalf("Failed to write corrupted file: %v", err)
|
||||
}
|
||||
|
||||
// Should return error for corrupted file
|
||||
_, err := store.Load(ctx, providerURL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for corrupted JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
|
||||
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(deepPath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
|
||||
err := store.Save(ctx, "https://example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save with nested directory: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after nested directory creation: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials from nested directory")
|
||||
}
|
||||
}
|
||||
+34
-1
@@ -384,10 +384,14 @@ scopes:
|
||||
|
||||
### Dynamic Client Registration (RFC 7591)
|
||||
|
||||
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
|
||||
|
||||
**Basic Configuration (Single Instance):**
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
initialAccessToken: "your-token" # Optional
|
||||
initialAccessToken: "your-token" # Optional, if provider requires it
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-credentials.json"
|
||||
clientMetadata:
|
||||
@@ -400,6 +404,35 @@ dynamicClientRegistration:
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
**Multi-Replica Deployment (Kubernetes):**
|
||||
|
||||
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: "redis" # Share credentials via Redis
|
||||
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "redis"
|
||||
```
|
||||
|
||||
**Storage Backend Options:**
|
||||
|
||||
| Backend | Description | Use Case |
|
||||
|---------|-------------|----------|
|
||||
| `file` | Store credentials in local file | Single instance deployments |
|
||||
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
|
||||
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
|
||||
|
||||
### Multi-Replica Deployment
|
||||
|
||||
Without Redis, disable replay detection:
|
||||
|
||||
+110
-1
@@ -90,6 +90,7 @@
|
||||
<a href="#configuration" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Configuration</a>
|
||||
<a href="#deployment" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Deployment</a>
|
||||
<a href="#security" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Security</a>
|
||||
<a href="#logout" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Logout</a>
|
||||
</div>
|
||||
<div class="flex items-center space-x-4">
|
||||
<button id="theme-toggle" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 p-2 min-w-[44px] min-h-[44px] flex items-center justify-center" aria-label="Toggle theme">
|
||||
@@ -114,6 +115,7 @@
|
||||
<a href="#configuration" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Configuration</a>
|
||||
<a href="#deployment" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Deployment</a>
|
||||
<a href="#security" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Security</a>
|
||||
<a href="#logout" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Logout</a>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
@@ -193,7 +195,7 @@
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-1">Dynamic Registration</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration</p>
|
||||
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration with Redis storage support for multi-replica deployments</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -862,6 +864,48 @@ spec:
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Dynamic Client Registration (RFC 7591)</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Automatically register your application with the OIDC provider. Supports Redis storage for multi-replica deployments:</p>
|
||||
<div class="overflow-x-auto mb-4">
|
||||
<table class="w-full text-sm">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200 dark:border-gray-700">
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Parameter</th>
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Default</th>
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Description</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="text-gray-600 dark:text-gray-400">
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.enabled</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable dynamic client registration</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.persistCredentials</code></td>
|
||||
<td class="py-2 px-3">true</td>
|
||||
<td class="py-2 px-3">Persist registered credentials across restarts</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.storageBackend</code></td>
|
||||
<td class="py-2 px-3">auto</td>
|
||||
<td class="py-2 px-3">Storage backend: <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">file</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis</code>, or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">auto</code> (uses Redis if available)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.redisKeyPrefix</code></td>
|
||||
<td class="py-2 px-3">dcr:creds:</td>
|
||||
<td class="py-2 px-3">Redis key prefix for DCR credentials</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.clientMetadata.redirect_uris</code></td>
|
||||
<td class="py-2 px-3">-</td>
|
||||
<td class="py-2 px-3">Redirect URIs for the registered client (required)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Security Headers with CORS</h3>
|
||||
|
||||
@@ -1177,6 +1221,71 @@ spec:
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- IdP-Initiated Logout Section -->
|
||||
<section id="logout" class="py-12 sm:py-16 md:py-20 bg-white dark:bg-gray-900 theme-transition">
|
||||
<div class="max-w-6xl mx-auto px-4 sm:px-6">
|
||||
<div class="text-center mb-8 sm:mb-12">
|
||||
<h2 class="text-2xl sm:text-3xl md:text-4xl font-bold text-gray-900 dark:text-gray-100 mb-3 sm:mb-4">IdP-Initiated Logout</h2>
|
||||
<p class="text-base sm:text-lg text-gray-600 dark:text-gray-300 px-4">Support for OIDC Back-Channel and Front-Channel Logout specifications</p>
|
||||
</div>
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="grid md:grid-cols-2 gap-6 mb-8">
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
|
||||
<i class="fas fa-server mr-2 text-blue-500"></i>
|
||||
Back-Channel Logout
|
||||
</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
|
||||
Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out.
|
||||
</p>
|
||||
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
|
||||
<li>• Signed JWT logout tokens</li>
|
||||
<li>• Session ID (sid) based invalidation</li>
|
||||
<li>• Subject (sub) based invalidation</li>
|
||||
<li>• Works behind firewalls</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
|
||||
<i class="fas fa-browser mr-2 text-purple-500"></i>
|
||||
Front-Channel Logout
|
||||
</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
|
||||
Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout.
|
||||
</p>
|
||||
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
|
||||
<li>• Iframe-based session termination</li>
|
||||
<li>• Immediate cookie invalidation</li>
|
||||
<li>• Simple GET request handling</li>
|
||||
<li>• Issuer validation</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Configuration Example</h3>
|
||||
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# ... other OIDC configuration ...
|
||||
|
||||
# Back-Channel Logout (server-to-server)
|
||||
enableBackchannelLogout: true
|
||||
backchannelLogoutURL: "/backchannel-logout"
|
||||
|
||||
# Front-Channel Logout (browser-based)
|
||||
enableFrontchannelLogout: true
|
||||
frontchannelLogoutURL: "/frontchannel-logout"</code></pre>
|
||||
<p class="text-gray-600 dark:text-gray-400 text-sm mt-4">
|
||||
Configure your IdP with the full URLs (e.g., <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">https://your-app.example.com/backchannel-logout</code>).
|
||||
When a user logs out from the IdP, all their sessions across your applications will be invalidated.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Why Choose Section -->
|
||||
<section class="py-12 sm:py-16 md:py-20 bg-gray-50 dark:bg-gray-800 theme-transition">
|
||||
<div class="max-w-6xl mx-auto px-4 sm:px-6">
|
||||
|
||||
@@ -50,6 +50,7 @@ type DynamicClientRegistrar struct {
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
store DCRCredentialsStore // Storage backend for credentials
|
||||
providerURL string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -73,8 +74,37 @@ func NewDynamicClientRegistrar(
|
||||
}
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
|
||||
func NewDynamicClientRegistrarWithStore(
|
||||
httpClient *http.Client,
|
||||
logger *Logger,
|
||||
dcrConfig *DynamicClientRegistrationConfig,
|
||||
providerURL string,
|
||||
store DCRCredentialsStore,
|
||||
) *DynamicClientRegistrar {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
return &DynamicClientRegistrar{
|
||||
httpClient: httpClient,
|
||||
logger: logger,
|
||||
config: dcrConfig,
|
||||
providerURL: providerURL,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// SetStore sets the credentials store for the registrar
|
||||
// This allows setting the store after creation when the cache manager is available
|
||||
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.store = store
|
||||
}
|
||||
|
||||
// RegisterClient performs dynamic client registration with the OIDC provider
|
||||
// It first attempts to load existing credentials from a file if persistence is enabled,
|
||||
// It first attempts to load existing credentials from storage if persistence is enabled,
|
||||
// then registers a new client if no valid credentials exist.
|
||||
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
|
||||
if r.config == nil || !r.config.Enabled {
|
||||
@@ -83,10 +113,13 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
|
||||
|
||||
// Try to load existing credentials if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if resp, err := r.loadCredentials(); err == nil && resp != nil {
|
||||
resp, err := r.loadCredentialsFromStore(ctx)
|
||||
if err != nil {
|
||||
r.logger.Debugf("Failed to load credentials from store: %v", err)
|
||||
} else if resp != nil {
|
||||
// Check if credentials are still valid (not expired)
|
||||
if r.areCredentialsValid(resp) {
|
||||
r.logger.Info("Loaded existing client credentials from file")
|
||||
r.logger.Info("Loaded existing client credentials from storage")
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = resp
|
||||
r.mu.Unlock()
|
||||
@@ -179,7 +212,7 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
|
||||
|
||||
// Persist credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist client credentials: %v", err)
|
||||
// Don't fail registration if persistence fails
|
||||
}
|
||||
@@ -315,7 +348,44 @@ func (r *DynamicClientRegistrar) credentialsFilePath() string {
|
||||
return "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file
|
||||
// loadCredentialsFromStore loads client credentials from the configured storage backend
|
||||
// Falls back to legacy file-based loading if no store is configured
|
||||
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Load(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based loading
|
||||
return r.loadCredentials()
|
||||
}
|
||||
|
||||
// saveCredentialsToStore persists client credentials to the configured storage backend
|
||||
// Falls back to legacy file-based saving if no store is configured
|
||||
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Save(ctx, r.providerURL, resp)
|
||||
}
|
||||
// Fallback to legacy file-based saving
|
||||
return r.saveCredentials(resp)
|
||||
}
|
||||
|
||||
// deleteCredentialsFromStore removes credentials from the configured storage backend
|
||||
// Falls back to legacy file-based deletion if no store is configured
|
||||
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Delete(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based deletion
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
@@ -333,7 +403,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials loads client credentials from a file
|
||||
// loadCredentials loads client credentials from a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
@@ -420,7 +490,7 @@ func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -527,11 +597,10 @@ func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) e
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials file if persistence is enabled
|
||||
// Remove credentials from storage if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
r.logger.Errorf("Failed to remove credentials file: %v", err)
|
||||
if err := r.deleteCredentialsFromStore(ctx); err != nil {
|
||||
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -336,6 +336,7 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
// and redirects to the provider's logout endpoint or configured post-logout URI.
|
||||
// It handles potential errors during session retrieval or clearing.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Processing logout request")
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FileStore implements Store using file-based storage.
|
||||
// This is the default storage backend for backward compatibility with existing deployments.
|
||||
// For distributed environments, consider using RedisStore instead.
|
||||
type FileStore struct {
|
||||
basePath string
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileStore creates a new file-based credentials store.
|
||||
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
|
||||
func NewFileStore(basePath string, logger Logger) *FileStore {
|
||||
if basePath == "" {
|
||||
basePath = "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
if logger == nil {
|
||||
logger = NoOpLogger()
|
||||
}
|
||||
return &FileStore{
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// BasePath returns the base path used for storing credentials
|
||||
func (s *FileStore) BasePath() string {
|
||||
return s.basePath
|
||||
}
|
||||
|
||||
// GetFilePath returns the file path for storing credentials for a specific provider.
|
||||
// For multi-tenant scenarios, each provider gets a separate file based on URL hash.
|
||||
func (s *FileStore) GetFilePath(providerURL string) string {
|
||||
if providerURL == "" {
|
||||
return s.basePath
|
||||
}
|
||||
|
||||
// Hash provider URL for filename safety and uniqueness
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes for shorter filename
|
||||
|
||||
ext := filepath.Ext(s.basePath)
|
||||
base := strings.TrimSuffix(s.basePath, ext)
|
||||
if ext == "" {
|
||||
ext = ".json"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s%s", base, hashStr, ext)
|
||||
}
|
||||
|
||||
// Save stores the client registration response to a file
|
||||
func (s *FileStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
if creds == nil {
|
||||
return fmt.Errorf("credentials cannot be nil")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
// Ensure parent directory exists
|
||||
dir := filepath.Dir(filePath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create credentials directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(creds, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions (owner read/write only)
|
||||
if err := os.WriteFile(filePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Saved client credentials to %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves stored credentials from a file.
|
||||
// Returns nil, nil if no credentials file exists (not an error).
|
||||
func (s *FileStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via GetFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No credentials file exists - not an error
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
||||
}
|
||||
|
||||
var creds ClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Loaded client credentials from %s", filePath)
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// Delete removes the credentials file for a provider
|
||||
func (s *FileStore) Delete(ctx context.Context, providerURL string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to delete
|
||||
}
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Deleted client credentials from %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
func (s *FileStore) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
_, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check credentials file: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache defines the interface for cache operations needed by RedisStore.
|
||||
// This allows the main package to provide a cache implementation without
|
||||
// creating circular dependencies.
|
||||
type Cache interface {
|
||||
// Get retrieves a value from the cache
|
||||
Get(key string) (any, bool)
|
||||
// Set stores a value in the cache with a TTL
|
||||
Set(key string, value any, ttl time.Duration) error
|
||||
// Delete removes a value from the cache
|
||||
Delete(key string)
|
||||
}
|
||||
|
||||
// RedisStore implements Store using a Cache-backed storage.
|
||||
// This storage backend enables sharing DCR credentials across multiple Traefik instances
|
||||
// in distributed environments (e.g., Kubernetes with multiple ingress pods).
|
||||
type RedisStore struct {
|
||||
cache Cache
|
||||
keyPrefix string
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new cache-backed credentials store.
|
||||
// The cache should be configured with a Redis backend for distributed storage.
|
||||
// If keyPrefix is empty, defaults to "dcr:creds:"
|
||||
func NewRedisStore(cache Cache, keyPrefix string, logger Logger) *RedisStore {
|
||||
if keyPrefix == "" {
|
||||
keyPrefix = "dcr:creds:"
|
||||
}
|
||||
if logger == nil {
|
||||
logger = NoOpLogger()
|
||||
}
|
||||
return &RedisStore{
|
||||
cache: cache,
|
||||
keyPrefix: keyPrefix,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// makeKey creates a unique cache key for a provider URL.
|
||||
// Uses SHA256 hash of the provider URL for consistent key generation across nodes.
|
||||
func (s *RedisStore) makeKey(providerURL string) string {
|
||||
if providerURL == "" {
|
||||
return s.keyPrefix + "default"
|
||||
}
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
return s.keyPrefix + hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Save stores the client registration response in the cache.
|
||||
// TTL is calculated based on client_secret_expires_at if available.
|
||||
func (s *RedisStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
if creds == nil {
|
||||
return fmt.Errorf("credentials cannot be nil")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
|
||||
// Calculate TTL based on client_secret_expires_at if available
|
||||
ttl := 30 * 24 * time.Hour // Default: 30 days
|
||||
if creds.ClientSecretExpiresAt > 0 {
|
||||
expiresAt := time.Unix(creds.ClientSecretExpiresAt, 0)
|
||||
ttl = time.Until(expiresAt)
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("credentials already expired")
|
||||
}
|
||||
// Add a small buffer to ensure we don't serve expired credentials
|
||||
if ttl > time.Minute {
|
||||
ttl -= time.Minute
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize credentials to JSON for storage
|
||||
data, err := json.Marshal(creds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Store as string in cache (will be serialized by the cache backend)
|
||||
if err := s.cache.Set(key, string(data), ttl); err != nil {
|
||||
return fmt.Errorf("failed to store credentials in cache: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Saved client credentials to cache with key %s (TTL: %v)", key, ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves stored credentials from the cache.
|
||||
// Returns nil, nil if no credentials exist (not an error).
|
||||
func (s *RedisStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
|
||||
value, exists := s.cache.Get(key)
|
||||
if !exists {
|
||||
return nil, nil // No credentials stored - not an error
|
||||
}
|
||||
|
||||
// Handle different value types from cache
|
||||
var jsonData string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
jsonData = v
|
||||
case []byte:
|
||||
jsonData = string(v)
|
||||
default:
|
||||
// Try to see if it's already the struct (from local cache)
|
||||
if creds, ok := value.(*ClientRegistrationResponse); ok {
|
||||
return creds, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected credentials type in cache: %T", value)
|
||||
}
|
||||
|
||||
var creds ClientRegistrationResponse
|
||||
if err := json.Unmarshal([]byte(jsonData), &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials from cache: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Loaded client credentials from cache with key %s", key)
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// Delete removes stored credentials from the cache
|
||||
func (s *RedisStore) Delete(ctx context.Context, providerURL string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
s.cache.Delete(key)
|
||||
|
||||
s.logger.Debugf("Deleted client credentials from cache with key %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if credentials exist in the cache for a provider
|
||||
func (s *RedisStore) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
_, exists := s.cache.Get(key)
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
// Package dcrstorage provides storage backends for OIDC Dynamic Client Registration credentials.
|
||||
// It supports both file-based and Redis-based storage for persisting client credentials
|
||||
// across application restarts and distributed deployments.
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// StorageBackend represents the type of storage backend for DCR credentials
|
||||
type StorageBackend string
|
||||
|
||||
const (
|
||||
// StorageBackendFile uses file-based storage (default for backward compatibility)
|
||||
StorageBackendFile StorageBackend = "file"
|
||||
|
||||
// StorageBackendRedis uses Redis for distributed storage
|
||||
StorageBackendRedis StorageBackend = "redis"
|
||||
|
||||
// StorageBackendAuto automatically selects Redis if available, otherwise file
|
||||
StorageBackendAuto StorageBackend = "auto"
|
||||
)
|
||||
|
||||
// Logger interface for DCR storage operations
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...any)
|
||||
Info(msg string)
|
||||
Infof(format string, args ...any)
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
}
|
||||
|
||||
// Store defines the interface for storing DCR credentials.
|
||||
// This abstraction allows different storage backends (file, Redis) to be used
|
||||
// for persisting OIDC Dynamic Client Registration credentials across nodes.
|
||||
type Store interface {
|
||||
// Save stores the client registration response for a provider
|
||||
// The providerURL is used as a key to support multi-tenant scenarios
|
||||
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
|
||||
|
||||
// Load retrieves stored credentials for a provider
|
||||
// Returns nil, nil if no credentials exist (not an error)
|
||||
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
|
||||
|
||||
// Delete removes stored credentials for a provider
|
||||
Delete(ctx context.Context, providerURL string) error
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
Exists(ctx context.Context, providerURL string) (bool, error)
|
||||
}
|
||||
|
||||
// noOpLogger is a no-op implementation of Logger for default use
|
||||
type noOpLogger struct{}
|
||||
|
||||
func (n noOpLogger) Debug(msg string) {}
|
||||
func (n noOpLogger) Debugf(format string, args ...any) {}
|
||||
func (n noOpLogger) Info(msg string) {}
|
||||
func (n noOpLogger) Infof(format string, args ...any) {}
|
||||
func (n noOpLogger) Error(msg string) {}
|
||||
func (n noOpLogger) Errorf(format string, args ...any) {}
|
||||
|
||||
// NoOpLogger returns a no-op logger instance
|
||||
func NoOpLogger() Logger {
|
||||
return noOpLogger{}
|
||||
}
|
||||
@@ -0,0 +1,464 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockCache implements Cache for testing
|
||||
type mockCache struct {
|
||||
data map[string]cacheEntry
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value any
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newMockCache() *mockCache {
|
||||
return &mockCache{data: make(map[string]cacheEntry)}
|
||||
}
|
||||
|
||||
func (m *mockCache) Get(key string) (any, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
entry, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
return entry.value, true
|
||||
}
|
||||
|
||||
func (m *mockCache) Set(key string, value any, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = cacheEntry{
|
||||
value: value,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCache) Delete(key string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
}
|
||||
|
||||
func TestFileStore_SaveLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "test-access-token",
|
||||
RegistrationClientURI: "https://example.com/register/test-client-id",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
ResponseTypes: []string{"code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
|
||||
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
tempDir2 := t.TempDir()
|
||||
store2 := NewFileStore(filepath.Join(tempDir2, "nonexistent.json"), nil)
|
||||
|
||||
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent file: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
|
||||
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Expected credentials to not exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existent credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_MultiProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
provider1 := "https://auth1.example.com"
|
||||
provider2 := "https://auth2.example.com"
|
||||
|
||||
creds1 := &ClientRegistrationResponse{
|
||||
ClientID: "client-1",
|
||||
ClientSecret: "secret-1",
|
||||
}
|
||||
creds2 := &ClientRegistrationResponse{
|
||||
ClientID: "client-2",
|
||||
ClientSecret: "secret-2",
|
||||
}
|
||||
|
||||
if err := store.Save(ctx, provider1, creds1); err != nil {
|
||||
t.Fatalf("Failed to save creds1: %v", err)
|
||||
}
|
||||
if err := store.Save(ctx, provider2, creds2); err != nil {
|
||||
t.Fatalf("Failed to save creds2: %v", err)
|
||||
}
|
||||
|
||||
loaded1, err := store.Load(ctx, provider1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds1: %v", err)
|
||||
}
|
||||
if loaded1.ClientID != "client-1" {
|
||||
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
|
||||
}
|
||||
|
||||
loaded2, err := store.Load(ctx, provider2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds2: %v", err)
|
||||
}
|
||||
if loaded2.ClientID != "client-2" {
|
||||
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
|
||||
}
|
||||
|
||||
if err := store.Delete(ctx, provider1); err != nil {
|
||||
t.Fatalf("Failed to delete creds1: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, provider2)
|
||||
if !exists {
|
||||
t.Error("Provider 2 credentials should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
|
||||
for range concurrency {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = store.Save(ctx, providerURL, creds)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for range concurrency {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = store.Load(ctx, providerURL)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after concurrent access: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test-client" {
|
||||
t.Error("Credentials corrupted after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty provider URL uses default path", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
err := store.Save(ctx, "", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with empty provider URL failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with empty provider URL failed: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials with empty provider URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_DefaultPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewFileStore("", nil)
|
||||
|
||||
if store.BasePath() == "" {
|
||||
t.Error("Expected default base path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisStore_WithMockCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "redis-test-client",
|
||||
ClientSecret: "redis-test-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "redis-test-token",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
}
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisStore_TTLFromExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("expired credentials should fail", func(t *testing.T) {
|
||||
expiredCreds := &ClientRegistrationResponse{
|
||||
ClientID: "expired-client",
|
||||
ClientSecret: "expired-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "no-expiry-client",
|
||||
ClientSecret: "no-expiry-secret",
|
||||
ClientSecretExpiresAt: 0,
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://noexpiry.example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials without expiry: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_CorruptedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
filePath := store.GetFilePath(providerURL)
|
||||
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
|
||||
t.Fatalf("Failed to write corrupted file: %v", err)
|
||||
}
|
||||
|
||||
_, err := store.Load(ctx, providerURL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for corrupted JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_DirectoryCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
|
||||
store := NewFileStore(deepPath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
|
||||
err := store.Save(ctx, "https://example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save with nested directory: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after nested directory creation: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials from nested directory")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,502 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
|
||||
// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// logoutTokenType is the expected typ claim for logout tokens
|
||||
// #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
|
||||
logoutTokenType = "logout+jwt"
|
||||
|
||||
// sessionInvalidationTTL is how long to remember invalidated sessions
|
||||
// Should be at least as long as your session max age
|
||||
sessionInvalidationTTL = 25 * time.Hour
|
||||
)
|
||||
|
||||
// LogoutTokenClaims represents the claims in an OIDC logout token
|
||||
// as defined in OpenID Connect Back-Channel Logout 1.0
|
||||
type LogoutTokenClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audience interface{} `json:"aud"` // Can be string or []string
|
||||
IssuedAt int64 `json:"iat"`
|
||||
JTI string `json:"jti"`
|
||||
Events map[string]interface{} `json:"events"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"` // Must NOT be present
|
||||
}
|
||||
|
||||
// handleBackchannelLogout processes OIDC Backchannel Logout requests.
|
||||
// It accepts POST requests with a logout_token parameter containing a JWT
|
||||
// that identifies which session(s) to terminate.
|
||||
//
|
||||
// According to OpenID Connect Back-Channel Logout 1.0:
|
||||
// - The logout_token is a JWT signed by the IdP
|
||||
// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
|
||||
// - The RP must validate the token and invalidate the matching session(s)
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer
|
||||
// - req: The HTTP request containing the logout_token
|
||||
func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Processing backchannel logout request")
|
||||
|
||||
// Backchannel logout must be POST
|
||||
if req.Method != http.MethodPost {
|
||||
t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
|
||||
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse form data to get logout_token
|
||||
if err := req.ParseForm(); err != nil {
|
||||
t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
|
||||
http.Error(rw, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logoutToken := req.FormValue("logout_token")
|
||||
if logoutToken == "" {
|
||||
// Also try reading from request body as raw JWT
|
||||
body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
|
||||
if err == nil && len(body) > 0 {
|
||||
logoutToken = string(body)
|
||||
}
|
||||
}
|
||||
|
||||
if logoutToken == "" {
|
||||
t.logger.Error("Backchannel logout: missing logout_token")
|
||||
http.Error(rw, "logout_token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and validate the logout token
|
||||
claims, err := t.validateLogoutToken(logoutToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
|
||||
// Return 400 for invalid token per spec
|
||||
http.Error(rw, "Invalid logout token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate session(s) based on sid or sub
|
||||
if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
|
||||
t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
|
||||
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
|
||||
claims.SessionID, claims.Subject)
|
||||
|
||||
// Return 200 OK with empty body per spec
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
|
||||
// It accepts GET requests with 'iss' and 'sid' query parameters that identify
|
||||
// which session to terminate. The IdP typically loads this URL in an iframe.
|
||||
//
|
||||
// According to OpenID Connect Front-Channel Logout 1.0:
|
||||
// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
|
||||
// - The RP should clear the session and return a response (typically empty or image)
|
||||
// - The response must be cacheable to allow the IdP to load it in an iframe
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer
|
||||
// - req: The HTTP request containing iss and sid parameters
|
||||
func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
t.logger.Debug("Processing front-channel logout request")
|
||||
|
||||
// Front-channel logout should be GET
|
||||
if req.Method != http.MethodGet {
|
||||
t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
|
||||
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Get iss and sid from query parameters
|
||||
iss := req.URL.Query().Get("iss")
|
||||
sid := req.URL.Query().Get("sid")
|
||||
|
||||
// Validate issuer matches our expected issuer
|
||||
t.metadataMu.RLock()
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if iss != "" && iss != expectedIssuer {
|
||||
t.logger.Errorf("Front-channel logout: issuer mismatch: got %s, expected %s", iss, expectedIssuer)
|
||||
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Must have at least sid for front-channel logout
|
||||
if sid == "" {
|
||||
t.logger.Error("Front-channel logout: missing sid parameter")
|
||||
http.Error(rw, "sid parameter required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate the session
|
||||
if err := t.invalidateSession(sid, ""); err != nil {
|
||||
t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
|
||||
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
|
||||
|
||||
// Return a minimal HTML response that's suitable for iframe loading
|
||||
// Set headers to allow embedding and caching
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.Header().Set("Cache-Control", "no-cache, no-store")
|
||||
rw.Header().Set("Pragma", "no-cache")
|
||||
// Allow embedding in iframes from any origin (required for front-channel logout)
|
||||
rw.Header().Del("X-Frame-Options")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write([]byte("<!DOCTYPE html><html><head><title>Logged Out</title></head><body></body></html>"))
|
||||
}
|
||||
|
||||
// validateLogoutToken parses and validates a logout token JWT.
|
||||
// It verifies the token signature, issuer, audience, and required claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT logout token
|
||||
//
|
||||
// Returns:
|
||||
// - The parsed logout token claims
|
||||
// - An error if validation fails
|
||||
func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(tokenString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse logout token: %w", err)
|
||||
}
|
||||
|
||||
// Check token type if present
|
||||
if typ, ok := jwt.Header["typ"].(string); ok {
|
||||
// The typ should be "logout+jwt" or omitted
|
||||
if typ != "" && typ != logoutTokenType && typ != "JWT" {
|
||||
return nil, fmt.Errorf("invalid token type: %s", typ)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify signature only (not standard claims - logout tokens don't have 'exp')
|
||||
if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
|
||||
return nil, fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Extract claims
|
||||
claims := &LogoutTokenClaims{}
|
||||
claimsJSON, err := json.Marshal(jwt.Claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal claims: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(claimsJSON, claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
|
||||
}
|
||||
|
||||
// Validate required claims
|
||||
t.metadataMu.RLock()
|
||||
expectedIssuer := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
// Validate issuer
|
||||
if claims.Issuer != expectedIssuer {
|
||||
return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
|
||||
}
|
||||
|
||||
// Validate audience (must contain our client_id)
|
||||
if !t.validateLogoutTokenAudience(claims.Audience) {
|
||||
return nil, fmt.Errorf("audience validation failed")
|
||||
}
|
||||
|
||||
// Validate iat (issued at) - must be present and not too old
|
||||
if claims.IssuedAt == 0 {
|
||||
return nil, fmt.Errorf("missing iat claim")
|
||||
}
|
||||
iatTime := time.Unix(claims.IssuedAt, 0)
|
||||
// Allow up to 5 minutes clock skew and 10 minutes token age
|
||||
if time.Since(iatTime) > 15*time.Minute {
|
||||
return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
|
||||
}
|
||||
// Token should not be from the future (with 5 min clock skew tolerance)
|
||||
if iatTime.After(time.Now().Add(5 * time.Minute)) {
|
||||
return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
|
||||
}
|
||||
|
||||
// Validate events claim - must contain the logout event
|
||||
if claims.Events == nil {
|
||||
return nil, fmt.Errorf("missing events claim")
|
||||
}
|
||||
if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
|
||||
return nil, fmt.Errorf("missing backchannel-logout event in events claim")
|
||||
}
|
||||
|
||||
// Validate that nonce is NOT present (per spec)
|
||||
if claims.Nonce != "" {
|
||||
return nil, fmt.Errorf("nonce claim must not be present in logout token")
|
||||
}
|
||||
|
||||
// Must have either sid or sub (or both)
|
||||
if claims.SessionID == "" && claims.Subject == "" {
|
||||
return nil, fmt.Errorf("logout token must contain either sid or sub claim")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// validateLogoutTokenAudience checks if the logout token audience contains our client_id
|
||||
func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
|
||||
switch v := aud.(type) {
|
||||
case string:
|
||||
return v == t.clientID
|
||||
case []interface{}:
|
||||
for _, a := range v {
|
||||
if s, ok := a.(string); ok && s == t.clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, a := range v {
|
||||
if a == t.clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// verifyLogoutTokenSignature verifies only the signature of a logout token.
|
||||
// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
|
||||
// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwt: The parsed JWT structure
|
||||
// - tokenString: The raw token string for signature verification
|
||||
//
|
||||
// Returns:
|
||||
// - An error if signature verification fails
|
||||
func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
|
||||
t.logger.Debug("Verifying logout token signature")
|
||||
|
||||
// Read jwksURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
jwksURL := t.jwksURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
if jwks == nil {
|
||||
return fmt.Errorf("JWKS is nil, cannot verify token")
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok || kid == "" {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
|
||||
alg, ok := jwt.Header["alg"].(string)
|
||||
if !ok || alg == "" {
|
||||
return fmt.Errorf("missing algorithm in token header")
|
||||
}
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
matchingKey = &key
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchingKey == nil {
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Debug("Logout token signature verified successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// invalidateSession marks a session as invalidated in the session invalidation cache.
|
||||
// It stores entries by both sid and sub if available.
|
||||
//
|
||||
// Parameters:
|
||||
// - sid: The session ID to invalidate (from the 'sid' claim)
|
||||
// - sub: The subject to invalidate (from the 'sub' claim)
|
||||
//
|
||||
// Returns:
|
||||
// - An error if the invalidation fails
|
||||
func (t *TraefikOidc) invalidateSession(sid, sub string) error {
|
||||
if t.sessionInvalidationCache == nil {
|
||||
return fmt.Errorf("session invalidation cache not initialized")
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
|
||||
// Store by session ID
|
||||
if sid != "" {
|
||||
key := t.buildSessionInvalidationKey("sid", sid)
|
||||
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
|
||||
t.logger.Debugf("Invalidated session by sid: %s", sid)
|
||||
}
|
||||
|
||||
// Store by subject (invalidates all sessions for this user)
|
||||
if sub != "" {
|
||||
key := t.buildSessionInvalidationKey("sub", sub)
|
||||
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
|
||||
t.logger.Debugf("Invalidated session by sub: %s", sub)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isSessionInvalidated checks if a session has been invalidated via backchannel
|
||||
// or front-channel logout.
|
||||
//
|
||||
// Parameters:
|
||||
// - sid: The session ID to check
|
||||
// - sub: The subject to check
|
||||
// - sessionCreatedAt: When the session was created (to compare against invalidation time)
|
||||
//
|
||||
// Returns:
|
||||
// - true if the session has been invalidated, false otherwise
|
||||
func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
|
||||
if t.sessionInvalidationCache == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Truncate session creation time to seconds for fair comparison with Unix timestamps
|
||||
sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
|
||||
|
||||
// Check by session ID first (more specific)
|
||||
if sid != "" {
|
||||
key := t.buildSessionInvalidationKey("sid", sid)
|
||||
if val, found := t.sessionInvalidationCache.Get(key); found {
|
||||
if invalidatedAt, ok := val.(int64); ok {
|
||||
// Session was invalidated at or after it was created
|
||||
invalidationTime := time.Unix(invalidatedAt, 0)
|
||||
if !invalidationTime.Before(sessionCreatedAtSec) {
|
||||
t.logger.Debugf("Session invalidated by sid: %s", sid)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check by subject (all sessions for this user)
|
||||
if sub != "" {
|
||||
key := t.buildSessionInvalidationKey("sub", sub)
|
||||
if val, found := t.sessionInvalidationCache.Get(key); found {
|
||||
if invalidatedAt, ok := val.(int64); ok {
|
||||
// Sessions for this subject created at or before invalidation are invalid
|
||||
invalidationTime := time.Unix(invalidatedAt, 0)
|
||||
if !invalidationTime.Before(sessionCreatedAtSec) {
|
||||
t.logger.Debugf("Session invalidated by sub: %s", sub)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// buildSessionInvalidationKey creates a cache key for session invalidation
|
||||
func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
|
||||
return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
|
||||
}
|
||||
|
||||
// extractSessionInfo extracts sid and sub from an ID token for session tracking
|
||||
func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
|
||||
if idToken == "" {
|
||||
return "", "", time.Time{}
|
||||
}
|
||||
|
||||
jwt, err := parseJWT(idToken)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}
|
||||
}
|
||||
|
||||
// Extract sid (session ID)
|
||||
if sidVal, ok := jwt.Claims["sid"].(string); ok {
|
||||
sid = sidVal
|
||||
}
|
||||
|
||||
// Extract sub (subject)
|
||||
if subVal, ok := jwt.Claims["sub"].(string); ok {
|
||||
sub = subVal
|
||||
}
|
||||
|
||||
// Extract iat for session creation time
|
||||
if iatVal, ok := jwt.Claims["iat"].(float64); ok {
|
||||
createdAt = time.Unix(int64(iatVal), 0)
|
||||
} else {
|
||||
// Default to now if iat not present
|
||||
createdAt = time.Now()
|
||||
}
|
||||
|
||||
return sid, sub, createdAt
|
||||
}
|
||||
|
||||
// determineLogoutPath checks if the given path matches any logout URL
|
||||
func (t *TraefikOidc) determineLogoutPath(path string) string {
|
||||
// Check backchannel logout path
|
||||
if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
|
||||
return "backchannel"
|
||||
}
|
||||
|
||||
// Check front-channel logout path
|
||||
if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
|
||||
return "frontchannel"
|
||||
}
|
||||
|
||||
// Check regular logout path (for RP-initiated logout)
|
||||
if path == t.logoutURLPath {
|
||||
return "rp"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
|
||||
func normalizeLogoutPath(path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
// Prevent open redirect: ensure second character is not / or \
|
||||
// This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
|
||||
if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
|
||||
// Strip leading slashes/backslashes and re-normalize
|
||||
path = strings.TrimLeft(path, "/\\")
|
||||
if path != "" {
|
||||
path = "/" + path
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
+1623
File diff suppressed because it is too large
Load Diff
@@ -212,16 +212,21 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
enableBackchannelLogout: config.EnableBackchannelLogout,
|
||||
enableFrontchannelLogout: config.EnableFrontchannelLogout,
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
@@ -433,6 +438,19 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
|
||||
t.dcrConfig,
|
||||
t.providerURL,
|
||||
)
|
||||
|
||||
// Set up storage backend for credentials persistence
|
||||
if t.dcrConfig.PersistCredentials {
|
||||
cacheManager := GetGlobalCacheManagerWithConfig(t.goroutineWG, nil)
|
||||
store, err := NewDCRCredentialsStore(t.dcrConfig, cacheManager, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to create DCR credentials store: %v", err)
|
||||
// Continue without persistence - registration will still work
|
||||
} else {
|
||||
t.dynamicClientRegistrar.SetStore(store)
|
||||
t.logger.Debugf("DCR credentials store initialized with backend: %s", t.dcrConfig.StorageBackend)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get registration endpoint (from metadata or config override)
|
||||
|
||||
+73
-6
@@ -26,6 +26,31 @@ import (
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The incoming HTTP request.
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Log request entry for debugging routing issues
|
||||
t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path)
|
||||
|
||||
// Handle logout requests early - before waiting for OIDC initialization
|
||||
// This allows users to logout even if the OIDC provider is unavailable
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.logger.Debugf("Logout path matched early: %s", req.URL.Path)
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle backchannel logout (IdP-initiated POST with logout_token)
|
||||
if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath {
|
||||
t.logger.Debug("Backchannel logout path matched")
|
||||
t.handleBackchannelLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle front-channel logout (IdP-initiated GET with sid/iss in iframe)
|
||||
if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath {
|
||||
t.logger.Debug("Front-channel logout path matched")
|
||||
t.handleFrontchannelLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
@@ -42,6 +67,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Check excluded URLs before waiting for initialization
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for SSE requests before waiting for initialization
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/event-stream") {
|
||||
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Log waiting for initialization to help diagnose hanging requests
|
||||
t.logger.Debug("Waiting for OIDC provider initialization...")
|
||||
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
// Read issuerURL with RLock
|
||||
@@ -83,13 +126,23 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
acceptHeader = req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/event-stream") {
|
||||
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
||||
// Set forwarded user headers from existing session before bypassing
|
||||
if session, err := t.sessionManager.GetSession(req); err == nil {
|
||||
defer session.returnToPoolSafely()
|
||||
if email := session.GetEmail(); email != "" {
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
}
|
||||
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
|
||||
}
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
t.sessionManager.CleanupOldCookies(rw, req)
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
@@ -120,10 +173,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
host := utils.DetermineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
return
|
||||
@@ -264,6 +313,24 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
return
|
||||
}
|
||||
|
||||
// Check if session has been invalidated via backchannel or front-channel logout
|
||||
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
|
||||
idToken := session.GetIDToken()
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
}
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenForClaims := session.GetIDToken()
|
||||
if tokenForClaims == "" {
|
||||
tokenForClaims = session.GetAccessToken()
|
||||
|
||||
@@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable
|
||||
// This is critical for allowing users to clear their session when the provider is down
|
||||
func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "/",
|
||||
forceHTTPS: false,
|
||||
}
|
||||
// Note: initComplete is NOT closed, simulating OIDC provider being unavailable
|
||||
|
||||
req := httptest.NewRequest("GET", "/logout", nil)
|
||||
req.Host = "example.com"
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// Should redirect to post-logout URI even without OIDC initialization
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected redirect (302) for logout, got %d", rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location == "" {
|
||||
t.Error("Expected Location header for logout redirect")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddlewareDomainRestrictions tests domain-based access control
|
||||
// NOTE: Currently commented out due to complex session setup requirements
|
||||
// These scenarios are tested indirectly through integration tests
|
||||
|
||||
+1
-12
@@ -1820,23 +1820,12 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
if token != "" {
|
||||
dotCount := strings.Count(token, ".")
|
||||
// Reject tokens with exactly 1 dot (invalid format - neither JWT nor opaque)
|
||||
if dotCount == 1 {
|
||||
if sd.manager != nil && sd.manager.logger != nil {
|
||||
sd.manager.logger.Debug("Invalid token format during storage (dots: %d) - rejecting", dotCount)
|
||||
}
|
||||
return
|
||||
}
|
||||
// For opaque tokens (no dots), ensure minimum length for security
|
||||
if dotCount == 0 && len(token) < 20 {
|
||||
if len(token) < 20 {
|
||||
if sd.manager != nil && sd.manager.logger != nil {
|
||||
sd.manager.logger.Debug("Token too short for opaque token (length: %d) - rejecting", len(token))
|
||||
}
|
||||
return
|
||||
}
|
||||
// Tokens with 2 dots are JWTs, tokens with 0 dots are opaque
|
||||
// Both are valid formats
|
||||
}
|
||||
|
||||
currentAccessToken := sd.getAccessTokenUnsafe()
|
||||
|
||||
+13
-116
@@ -65,6 +65,10 @@ type Config struct {
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
|
||||
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
|
||||
EnableBackchannelLogout bool `json:"enableBackchannelLogout,omitempty"`
|
||||
EnableFrontchannelLogout bool `json:"enableFrontchannelLogout,omitempty"`
|
||||
BackchannelLogoutURL string `json:"backchannelLogoutURL,omitempty"`
|
||||
FrontchannelLogoutURL string `json:"frontchannelLogoutURL,omitempty"`
|
||||
}
|
||||
|
||||
// RedisConfig configures Redis cache backend settings for distributed caching.
|
||||
@@ -98,8 +102,15 @@ type DynamicClientRegistrationConfig struct {
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
// StorageBackend specifies where to store DCR credentials: "file", "redis", or "auto"
|
||||
// - "file": Use file-based storage (default for backward compatibility)
|
||||
// - "redis": Use Redis exclusively (fails if Redis unavailable)
|
||||
// - "auto": Use Redis if available, fallback to file (default)
|
||||
StorageBackend string `json:"storageBackend,omitempty"`
|
||||
// RedisKeyPrefix is the prefix for Redis keys when using Redis storage (default: "dcr:creds:")
|
||||
RedisKeyPrefix string `json:"redisKeyPrefix,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
@@ -737,15 +748,6 @@ func newNoOpLogger() *Logger {
|
||||
// - code: The HTTP status code for the response.
|
||||
// - logger: The Logger instance to use for logging the error.
|
||||
//
|
||||
// handleError writes an HTTP error response with the specified status code and message.
|
||||
// It logs the error and sets appropriate headers before writing the response.
|
||||
//
|
||||
//lint:ignore U1000 Kept for potential future error handling
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error("%s", message)
|
||||
http.Error(w, message, code)
|
||||
}
|
||||
|
||||
// GetSecurityHeadersApplier returns a function that applies security headers
|
||||
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
|
||||
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
|
||||
@@ -1051,111 +1053,6 @@ func (rc *RedisConfig) ApplyEnvFallbacks() {
|
||||
}
|
||||
}
|
||||
|
||||
// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
|
||||
// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
|
||||
// This function is kept for backward compatibility but should not be used directly.
|
||||
func LoadRedisConfigFromEnv() *RedisConfig {
|
||||
// Check if Redis is enabled
|
||||
enabledStr := os.Getenv("REDIS_ENABLED")
|
||||
if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &RedisConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
// Parse numeric values
|
||||
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
|
||||
if db, err := strconv.Atoi(dbStr); err == nil {
|
||||
config.DB = db
|
||||
}
|
||||
}
|
||||
|
||||
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
|
||||
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
|
||||
config.PoolSize = poolSize
|
||||
}
|
||||
}
|
||||
|
||||
if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
|
||||
config.ConnectTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
|
||||
config.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
|
||||
config.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Parse boolean values
|
||||
if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
|
||||
config.EnableTLS = true
|
||||
}
|
||||
|
||||
if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
|
||||
config.TLSSkipVerify = true
|
||||
}
|
||||
|
||||
// Parse hybrid mode settings
|
||||
if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
|
||||
if size, err := strconv.Atoi(l1SizeStr); err == nil {
|
||||
config.HybridL1Size = size
|
||||
}
|
||||
}
|
||||
|
||||
if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
|
||||
if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
|
||||
config.HybridL1MemoryMB = memory
|
||||
}
|
||||
}
|
||||
|
||||
// Parse circuit breaker settings
|
||||
if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
|
||||
config.EnableCircuitBreaker = false
|
||||
} else {
|
||||
config.EnableCircuitBreaker = true // Default to enabled
|
||||
}
|
||||
|
||||
if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
|
||||
if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
|
||||
config.CircuitBreakerThreshold = threshold
|
||||
}
|
||||
}
|
||||
|
||||
if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
|
||||
config.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Parse health check settings
|
||||
if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
|
||||
config.EnableHealthCheck = false
|
||||
} else {
|
||||
config.EnableHealthCheck = true // Default to enabled
|
||||
}
|
||||
|
||||
if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
|
||||
if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
|
||||
config.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// Apply defaults after loading from env
|
||||
config.ApplyDefaults()
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
|
||||
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -251,6 +254,30 @@ func TestSingletonResourceManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// createMockOIDCServer creates a mock OIDC server for testing
|
||||
func createMockOIDCServer() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/authorize",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"jwks_uri": "https://example.com/jwks",
|
||||
"userinfo_endpoint": "https://example.com/userinfo",
|
||||
})
|
||||
case "/jwks":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"keys": []interface{}{},
|
||||
})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// TestContextAwareGoroutineManagement tests context-aware goroutine management
|
||||
func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
t.Run("GoroutineCleanupOnContextCancel", func(t *testing.T) {
|
||||
@@ -259,13 +286,17 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock OIDC server
|
||||
mockServer := createMockOIDCServer()
|
||||
defer mockServer.Close()
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create a TraefikOidc instance with context
|
||||
config := &Config{
|
||||
ProviderURL: "https://example.com",
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
@@ -308,12 +339,20 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock OIDC servers
|
||||
mockServer1 := createMockOIDCServer()
|
||||
defer mockServer1.Close()
|
||||
mockServer2 := createMockOIDCServer()
|
||||
defer mockServer2.Close()
|
||||
mockServer3 := createMockOIDCServer()
|
||||
defer mockServer3.Close()
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
configs := []Config{
|
||||
{ProviderURL: "https://example1.com", ClientID: "client1", ClientSecret: "secret1"},
|
||||
{ProviderURL: "https://example2.com", ClientID: "client2", ClientSecret: "secret2"},
|
||||
{ProviderURL: "https://example3.com", ClientID: "client3", ClientSecret: "secret3"},
|
||||
{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
|
||||
{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
|
||||
{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
|
||||
}
|
||||
|
||||
var plugins []*TraefikOidc
|
||||
@@ -366,6 +405,13 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock OIDC servers
|
||||
mockServers := make([]*httptest.Server, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
mockServers[i] = createMockOIDCServer()
|
||||
defer mockServers[i].Close()
|
||||
}
|
||||
|
||||
rm := GetResourceManager()
|
||||
|
||||
// Register singleton cleanup task
|
||||
@@ -386,7 +432,7 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
for i := 0; i < 3; i++ {
|
||||
ctx := context.Background()
|
||||
config := &Config{
|
||||
ProviderURL: fmt.Sprintf("https://example%d.com", i),
|
||||
ProviderURL: mockServers[i].URL,
|
||||
ClientID: fmt.Sprintf("client%d", i),
|
||||
ClientSecret: fmt.Sprintf("secret%d", i),
|
||||
}
|
||||
|
||||
@@ -119,6 +119,8 @@ type TraefikOidc struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
registrationURL string
|
||||
backchannelLogoutPath string
|
||||
frontchannelLogoutPath string
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
refreshGracePeriod time.Duration
|
||||
@@ -126,7 +128,10 @@ type TraefikOidc struct {
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
minimalHeaders bool
|
||||
enableBackchannelLogout bool
|
||||
enableFrontchannelLogout bool
|
||||
firstRequestReceived bool
|
||||
requireTokenIntrospection bool
|
||||
metadataRefreshStarted bool
|
||||
|
||||
+75
-20
@@ -21,6 +21,10 @@ const (
|
||||
CacheTypeJWK CacheType = "jwk"
|
||||
CacheTypeSession CacheType = "session"
|
||||
CacheTypeGeneral CacheType = "general"
|
||||
|
||||
// maxCacheEntrySize defines the maximum size for a single cache entry (64 MiB)
|
||||
// This prevents integer overflow when allocating memory for serialization
|
||||
maxCacheEntrySize = 64 * 1024 * 1024
|
||||
)
|
||||
|
||||
// UniversalCacheConfig provides configuration for the universal cache
|
||||
@@ -720,22 +724,6 @@ func (c *UniversalCache) SetWithMetadata(key string, value interface{}, ttl time
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTyped retrieves a typed value from the cache
|
||||
func GetTyped[T any](c *UniversalCache, key string) (T, bool) {
|
||||
var zero T
|
||||
value, exists := c.Get(key)
|
||||
if !exists {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
typed, ok := value.(T)
|
||||
if !ok {
|
||||
return zero, false
|
||||
}
|
||||
|
||||
return typed, true
|
||||
}
|
||||
|
||||
// TokenCacheOperations provides token-specific operations
|
||||
func (c *UniversalCache) BlacklistToken(token string, ttl time.Duration) error {
|
||||
if c.config.Type != CacheTypeToken {
|
||||
@@ -784,14 +772,81 @@ func (c *UniversalCache) Strategy() CacheStrategy {
|
||||
|
||||
// serialize converts a value to bytes for backend storage
|
||||
func (c *UniversalCache) serialize(value interface{}) ([]byte, error) {
|
||||
// Use JSON for serialization - simple and universal
|
||||
return json.Marshal(value)
|
||||
// If value is already a byte slice (e.g., pre-marshaled JSON from metadata_cache),
|
||||
// store it directly with a marker to prevent double-encoding.
|
||||
// This fixes the issue where []byte was being JSON-marshaled, causing Base64 encoding.
|
||||
if bytes, ok := value.([]byte); ok {
|
||||
// Validate size to prevent integer overflow
|
||||
if len(bytes) > maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("cache entry size %d exceeds maximum allowed size %d", len(bytes), maxCacheEntrySize)
|
||||
}
|
||||
// Check for potential overflow when adding marker byte
|
||||
if len(bytes) == maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("cache entry size would overflow when adding marker byte")
|
||||
}
|
||||
|
||||
// Prepend marker byte 0x00 to indicate raw bytes (not JSON-encoded)
|
||||
result := make([]byte, len(bytes)+1)
|
||||
result[0] = 0x00
|
||||
copy(result[1:], bytes)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// For all other types (maps, strings, etc.), use JSON encoding
|
||||
// Prepend marker byte 0x01 to indicate JSON-encoded data
|
||||
jsonData, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate size to prevent integer overflow
|
||||
if len(jsonData) > maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("serialized cache entry size %d exceeds maximum allowed size %d", len(jsonData), maxCacheEntrySize)
|
||||
}
|
||||
// Check for potential overflow when adding marker byte
|
||||
if len(jsonData) == maxCacheEntrySize {
|
||||
return nil, fmt.Errorf("serialized cache entry size would overflow when adding marker byte")
|
||||
}
|
||||
|
||||
result := make([]byte, len(jsonData)+1)
|
||||
result[0] = 0x01
|
||||
copy(result[1:], jsonData)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// deserialize converts bytes from backend storage to a value
|
||||
func (c *UniversalCache) deserialize(data []byte, value interface{}) error {
|
||||
// Use JSON for deserialization
|
||||
return json.Unmarshal(data, value)
|
||||
if len(data) == 0 {
|
||||
return fmt.Errorf("cannot deserialize empty data")
|
||||
}
|
||||
|
||||
// Check for type marker (added by serialize)
|
||||
if data[0] == 0x00 {
|
||||
// Raw bytes - strip marker and return as-is
|
||||
rawBytes := data[1:]
|
||||
if ptr, ok := value.(*interface{}); ok {
|
||||
*ptr = rawBytes
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("cannot deserialize raw bytes into %T", value)
|
||||
}
|
||||
|
||||
if data[0] == 0x01 {
|
||||
// JSON-encoded - strip marker and unmarshal
|
||||
return json.Unmarshal(data[1:], value)
|
||||
}
|
||||
|
||||
// Legacy data without marker (for backward compatibility)
|
||||
// Try to unmarshal as JSON
|
||||
if err := json.Unmarshal(data, value); err != nil {
|
||||
// If unmarshal fails, treat as raw bytes
|
||||
if ptr, ok := value.(*interface{}); ok {
|
||||
*ptr = data
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// prefixKey adds a cache type prefix to the key for backend storage
|
||||
|
||||
@@ -0,0 +1,517 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestUniversalCache_SerializeDeserialize tests the fix for issue #116
|
||||
// where metadata was stored as Base64-encoded JSON but read as plain JSON
|
||||
func TestUniversalCache_SerializeDeserialize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RawBytesPreserved", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Test data: pre-marshaled JSON bytes (like metadata_cache uses)
|
||||
testData := []byte(`{"issuer":"https://example.com","jwks_uri":"https://example.com/jwks"}`)
|
||||
|
||||
// Serialize
|
||||
serialized, err := cache.serialize(testData)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, serialized)
|
||||
|
||||
// Should have marker byte
|
||||
assert.Equal(t, byte(0x00), serialized[0], "Should have raw bytes marker")
|
||||
assert.Equal(t, testData, serialized[1:], "Data should be preserved after marker")
|
||||
|
||||
// Deserialize
|
||||
var result interface{}
|
||||
err = cache.deserialize(serialized, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get back []byte
|
||||
resultBytes, ok := result.([]byte)
|
||||
require.True(t, ok, "Result should be []byte")
|
||||
assert.Equal(t, testData, resultBytes, "Deserialized data should match original")
|
||||
})
|
||||
|
||||
t.Run("JSONEncodedTypes", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
}{
|
||||
{
|
||||
name: "Map",
|
||||
value: map[string]interface{}{"key": "value", "number": 42.0},
|
||||
},
|
||||
{
|
||||
name: "String",
|
||||
value: "test-string",
|
||||
},
|
||||
{
|
||||
name: "Number",
|
||||
value: 123.456,
|
||||
},
|
||||
{
|
||||
name: "Array",
|
||||
value: []interface{}{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Serialize
|
||||
serialized, err := cache.serialize(tc.value)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, serialized)
|
||||
|
||||
// Should have JSON marker byte
|
||||
assert.Equal(t, byte(0x01), serialized[0], "Should have JSON marker")
|
||||
|
||||
// Verify the JSON portion is valid
|
||||
var checkJSON interface{}
|
||||
err = json.Unmarshal(serialized[1:], &checkJSON)
|
||||
require.NoError(t, err, "Should be valid JSON after marker")
|
||||
|
||||
// Deserialize
|
||||
var result interface{}
|
||||
err = cache.deserialize(serialized, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare results (using JSON round-trip for consistent comparison)
|
||||
expectedJSON, _ := json.Marshal(tc.value)
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
assert.JSONEq(t, string(expectedJSON), string(resultJSON), "Deserialized data should match original")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LegacyDataCompatibility", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Simulate legacy data (JSON without marker byte)
|
||||
legacyData := []byte(`{"legacy":"data"}`)
|
||||
|
||||
var result interface{}
|
||||
err := cache.deserialize(legacyData, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should successfully unmarshal as JSON
|
||||
resultMap, ok := result.(map[string]interface{})
|
||||
require.True(t, ok, "Should unmarshal legacy JSON data")
|
||||
assert.Equal(t, "data", resultMap["legacy"])
|
||||
})
|
||||
|
||||
t.Run("EmptyDataHandling", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
var result interface{}
|
||||
err := cache.deserialize([]byte{}, &result)
|
||||
assert.Error(t, err, "Should error on empty data")
|
||||
assert.Contains(t, err.Error(), "empty data")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_LargeBytes", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a byte slice that exceeds maxCacheEntrySize (64 MiB)
|
||||
oversizedBytes := make([]byte, 65*1024*1024) // 65 MiB
|
||||
|
||||
// Attempt to serialize - should fail with overflow error
|
||||
_, err := cache.serialize(oversizedBytes)
|
||||
require.Error(t, err, "Should error on oversized byte slice")
|
||||
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_ExactMaxSize", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a byte slice exactly at maxCacheEntrySize
|
||||
// This should fail because adding marker byte would overflow
|
||||
exactMaxBytes := make([]byte, 64*1024*1024) // Exactly 64 MiB
|
||||
|
||||
_, err := cache.serialize(exactMaxBytes)
|
||||
require.Error(t, err, "Should error when adding marker would overflow")
|
||||
assert.Contains(t, err.Error(), "would overflow when adding marker byte")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_SafeSize", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a byte slice well within limits
|
||||
safeBytes := make([]byte, 1024*1024) // 1 MiB - safe size
|
||||
|
||||
serialized, err := cache.serialize(safeBytes)
|
||||
require.NoError(t, err, "Should succeed with safe size")
|
||||
assert.NotNil(t, serialized)
|
||||
assert.Equal(t, len(safeBytes)+1, len(serialized), "Should add marker byte")
|
||||
})
|
||||
|
||||
t.Run("OverflowProtection_JSONData", func(t *testing.T) {
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
// Create a very large map that will exceed limits when JSON-encoded
|
||||
largeMap := make(map[string]string)
|
||||
// Each entry is roughly 50 bytes, so we need ~1.3M entries to exceed 64 MiB
|
||||
for i := 0; i < 1400000; i++ {
|
||||
key := fmt.Sprintf("key_%d", i)
|
||||
largeMap[key] = "value_with_some_content_to_make_it_larger"
|
||||
}
|
||||
|
||||
_, err := cache.serialize(largeMap)
|
||||
require.Error(t, err, "Should error when JSON serialization exceeds size limit")
|
||||
assert.Contains(t, err.Error(), "exceeds maximum allowed size")
|
||||
})
|
||||
}
|
||||
|
||||
// TestUniversalCache_RedisIntegration_Issue116 tests the complete fix for issue #116
|
||||
// with actual Redis backend to ensure metadata cache works correctly
|
||||
func TestUniversalCache_RedisIntegration_Issue116(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Start miniredis server
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
// Create Redis backend
|
||||
redisConfig := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisConfig.RedisPrefix = "test:"
|
||||
backend, err := backends.NewRedisBackend(redisConfig)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
t.Run("MetadataCache_StoreAndRetrieve", func(t *testing.T) {
|
||||
// Create cache with Redis backend
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Simulate metadata_cache.Set behavior:
|
||||
// 1. Marshal metadata to JSON
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
TokenURL: "https://example.com/token",
|
||||
AuthURL: "https://example.com/authorize",
|
||||
}
|
||||
jsonData, err := json.Marshal(metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. Store the JSON bytes
|
||||
key := "v2:https://example.com"
|
||||
err = cache.Set(key, jsonData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 3. Retrieve the data
|
||||
retrieved, exists := cache.Get(key)
|
||||
require.True(t, exists, "Data should exist in cache")
|
||||
|
||||
// 4. Should get back []byte (not a string or map)
|
||||
retrievedBytes, ok := retrieved.([]byte)
|
||||
require.True(t, ok, "Retrieved value should be []byte, got %T", retrieved)
|
||||
|
||||
// 5. Should be able to unmarshal as JSON
|
||||
var retrievedMetadata ProviderMetadata
|
||||
err = json.Unmarshal(retrievedBytes, &retrievedMetadata)
|
||||
require.NoError(t, err, "Should be able to unmarshal retrieved bytes as JSON")
|
||||
|
||||
// 6. Verify data integrity
|
||||
assert.Equal(t, metadata.Issuer, retrievedMetadata.Issuer)
|
||||
assert.Equal(t, metadata.JWKSURL, retrievedMetadata.JWKSURL)
|
||||
assert.Equal(t, metadata.TokenURL, retrievedMetadata.TokenURL)
|
||||
})
|
||||
|
||||
t.Run("MetadataCache_NoBase64Encoding", func(t *testing.T) {
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Store JSON bytes
|
||||
jsonData := []byte(`{"issuer":"https://test.com"}`)
|
||||
key := "v2:https://test.com"
|
||||
err = cache.Set(key, jsonData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve
|
||||
retrieved, exists := cache.Get(key)
|
||||
require.True(t, exists)
|
||||
|
||||
retrievedBytes, ok := retrieved.([]byte)
|
||||
require.True(t, ok)
|
||||
|
||||
// The retrieved data should NOT start with "eyJ" (Base64 encoding of "{")
|
||||
// This was the bug in issue #116
|
||||
assert.NotEqual(t, []byte("eyJ"), retrievedBytes[:3], "Data should not be Base64 encoded")
|
||||
|
||||
// Should be valid JSON
|
||||
var checkJSON map[string]interface{}
|
||||
err = json.Unmarshal(retrievedBytes, &checkJSON)
|
||||
require.NoError(t, err, "Data should be valid JSON")
|
||||
assert.Equal(t, "https://test.com", checkJSON["issuer"])
|
||||
})
|
||||
|
||||
t.Run("TokenCache_MapValues", func(t *testing.T) {
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Store a map (like TokenCache does)
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": 1234567890.0,
|
||||
"scope": "read write",
|
||||
}
|
||||
key := "token:abc123"
|
||||
err = cache.Set(key, claims, 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve
|
||||
retrieved, exists := cache.Get(key)
|
||||
require.True(t, exists)
|
||||
|
||||
// Should get back a map
|
||||
retrievedMap, ok := retrieved.(map[string]interface{})
|
||||
require.True(t, ok, "Retrieved value should be map[string]interface{}")
|
||||
assert.Equal(t, "user123", retrievedMap["sub"])
|
||||
assert.Equal(t, 1234567890.0, retrievedMap["exp"])
|
||||
})
|
||||
|
||||
t.Run("MixedTypes_SameCache", func(t *testing.T) {
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
// Store different types
|
||||
jsonBytes := []byte(`{"type":"json-bytes"}`)
|
||||
err = cache.Set("key1", jsonBytes, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
mapData := map[string]interface{}{"type": "map"}
|
||||
err = cache.Set("key2", mapData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
stringData := "plain-string"
|
||||
err = cache.Set("key3", stringData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify each type
|
||||
val1, exists := cache.Get("key1")
|
||||
require.True(t, exists)
|
||||
bytes1, ok := val1.([]byte)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, jsonBytes, bytes1)
|
||||
|
||||
val2, exists := cache.Get("key2")
|
||||
require.True(t, exists)
|
||||
map2, ok := val2.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "map", map2["type"])
|
||||
|
||||
val3, exists := cache.Get("key3")
|
||||
require.True(t, exists)
|
||||
str3, ok := val3.(string)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, stringData, str3)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUniversalCache_BackwardCompatibility tests that old cached data is handled gracefully
|
||||
func TestUniversalCache_BackwardCompatibility(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
redisConfig := backends.DefaultRedisConfig(mr.Addr())
|
||||
backend, err := backends.NewRedisBackend(redisConfig)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("LegacyJSONData", func(t *testing.T) {
|
||||
// Manually insert legacy data (plain JSON without marker)
|
||||
legacyKey := "general:legacy-key"
|
||||
legacyData := []byte(`{"old":"format"}`)
|
||||
err = backend.Set(ctx, legacyKey, legacyData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to retrieve via UniversalCache
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
retrieved, exists := cache.Get("legacy-key")
|
||||
require.True(t, exists, "Should retrieve legacy data")
|
||||
|
||||
// Should deserialize as JSON map
|
||||
retrievedMap, ok := retrieved.(map[string]interface{})
|
||||
require.True(t, ok, "Should unmarshal legacy JSON")
|
||||
assert.Equal(t, "format", retrievedMap["old"])
|
||||
})
|
||||
|
||||
t.Run("LegacyCorruptData", func(t *testing.T) {
|
||||
// Insert corrupt/invalid data
|
||||
corruptKey := "general:corrupt-key"
|
||||
corruptData := []byte("not json and no marker")
|
||||
err = backend.Set(ctx, corruptKey, corruptData, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
}, backend)
|
||||
defer cache.Close()
|
||||
|
||||
retrieved, exists := cache.Get("corrupt-key")
|
||||
require.True(t, exists)
|
||||
|
||||
// Should return as raw bytes (fallback)
|
||||
retrievedBytes, ok := retrieved.([]byte)
|
||||
require.True(t, ok, "Should return corrupt data as raw bytes")
|
||||
assert.Equal(t, corruptData, retrievedBytes)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetadataCache_Issue116_Regression is the main regression test for issue #116
|
||||
// This specifically tests the scenario described in the GitHub issue
|
||||
func TestMetadataCache_Issue116_Regression(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
// Create Redis backend
|
||||
redisConfig := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisConfig.RedisPrefix = "traefik:"
|
||||
backend, err := backends.NewRedisBackend(redisConfig)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
// Create a simple logger
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
// Create metadata cache instance
|
||||
metadataCache := NewUniversalCacheWithBackend(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true,
|
||||
}, backend)
|
||||
defer metadataCache.Close()
|
||||
|
||||
// Use the actual MetadataCache wrapper
|
||||
wg := &sync.WaitGroup{}
|
||||
mc := &MetadataCache{
|
||||
cache: metadataCache,
|
||||
logger: logger,
|
||||
wg: wg,
|
||||
}
|
||||
|
||||
// Test: Store and retrieve metadata (the scenario from issue #116)
|
||||
providerURL := "https://example.com"
|
||||
metadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/authorize",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
RevokeURL: "https://example.com/revoke",
|
||||
EndSessionURL: "https://example.com/logout",
|
||||
RegistrationURL: "https://example.com/register",
|
||||
ScopesSupported: []string{"openid", "profile", "email"},
|
||||
}
|
||||
|
||||
// Store metadata
|
||||
err = mc.Set(providerURL, metadata, 1*time.Hour)
|
||||
require.NoError(t, err, "Should store metadata without error")
|
||||
|
||||
// Retrieve metadata
|
||||
retrieved, exists := mc.Get(providerURL)
|
||||
require.True(t, exists, "Should retrieve stored metadata")
|
||||
require.NotNil(t, retrieved, "Retrieved metadata should not be nil")
|
||||
|
||||
// Verify no corruption - this was failing in issue #116 with "invalid character 'e'" error
|
||||
assert.Equal(t, metadata.Issuer, retrieved.Issuer)
|
||||
assert.Equal(t, metadata.AuthURL, retrieved.AuthURL)
|
||||
assert.Equal(t, metadata.TokenURL, retrieved.TokenURL)
|
||||
assert.Equal(t, metadata.JWKSURL, retrieved.JWKSURL)
|
||||
|
||||
// Verify the data is not Base64-encoded in Redis
|
||||
// This checks the root cause mentioned in the issue
|
||||
ctx := context.Background()
|
||||
rawData, _, exists, err := backend.Get(ctx, "metadata:v2:"+providerURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
|
||||
// Strip the marker byte
|
||||
require.Greater(t, len(rawData), 1, "Data should have marker byte")
|
||||
dataWithoutMarker := rawData[1:]
|
||||
|
||||
// Should not start with "eyJ" (Base64 encoding of "{")
|
||||
if len(dataWithoutMarker) >= 3 {
|
||||
assert.NotEqual(t, "eyJ", string(dataWithoutMarker[:3]), "Data should not be Base64-encoded")
|
||||
}
|
||||
|
||||
// Should be valid JSON
|
||||
var checkMetadata ProviderMetadata
|
||||
err = json.Unmarshal(dataWithoutMarker, &checkMetadata)
|
||||
require.NoError(t, err, "Stored data should be valid JSON, not Base64")
|
||||
assert.Equal(t, metadata.Issuer, checkMetadata.Issuer)
|
||||
}
|
||||
@@ -13,20 +13,22 @@ import (
|
||||
// It runs a single consolidated cleanup goroutine for all caches, reducing
|
||||
// goroutine count and CPU overhead compared to per-cache cleanup routines.
|
||||
type UniversalCacheManager struct {
|
||||
sharedBackend backends.CacheBackend
|
||||
ctx context.Context
|
||||
tokenTypeCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache
|
||||
tokenCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
cleanupStarted bool
|
||||
sharedBackend backends.CacheBackend
|
||||
ctx context.Context
|
||||
tokenTypeCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache
|
||||
tokenCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
|
||||
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
cleanupStarted bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -169,6 +171,16 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// Initialize session invalidation cache for backchannel/front-channel logout
|
||||
// This cache stores invalidated session IDs and subjects to revoke sessions
|
||||
manager.sessionInvalidationCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 5000, // Support many concurrent invalidations
|
||||
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
}
|
||||
|
||||
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
|
||||
@@ -349,6 +361,32 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// DCR credentials cache - CRITICAL for distributed DCR across multiple nodes
|
||||
// Uses Redis backend to share client credentials across all Traefik replicas
|
||||
manager.dcrCredentialsCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100, // Few providers expected
|
||||
DefaultTTL: 30 * 24 * time.Hour, // 30 days default (credentials are long-lived)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("dcr"),
|
||||
)
|
||||
|
||||
// Session invalidation cache - CRITICAL for distributed backchannel/front-channel logout
|
||||
// Uses Redis backend to share session invalidations across all Traefik replicas
|
||||
manager.sessionInvalidationCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 5000, // Support many concurrent invalidations
|
||||
DefaultTTL: 25 * time.Hour, // Slightly longer than session max age (24h)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("session_invalidation"),
|
||||
)
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
@@ -396,6 +434,8 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
|
||||
m.sessionCache,
|
||||
m.introspectionCache,
|
||||
m.tokenTypeCache,
|
||||
m.dcrCredentialsCache,
|
||||
m.sessionInvalidationCache,
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -437,13 +477,6 @@ func (m *UniversalCacheManager) GetJWKCache() *UniversalCache {
|
||||
return m.jwkCache
|
||||
}
|
||||
|
||||
// GetSessionCache returns the session cache
|
||||
func (m *UniversalCacheManager) GetSessionCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.sessionCache
|
||||
}
|
||||
|
||||
// GetIntrospectionCache returns the token introspection cache
|
||||
func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
@@ -458,6 +491,20 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
|
||||
return m.tokenTypeCache
|
||||
}
|
||||
|
||||
// GetSessionInvalidationCache returns the session invalidation cache for backchannel/front-channel logout
|
||||
func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.sessionInvalidationCache
|
||||
}
|
||||
|
||||
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
|
||||
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.dcrCredentialsCache
|
||||
}
|
||||
|
||||
// Close shuts down all caches and the consolidated cleanup routine
|
||||
func (m *UniversalCacheManager) Close() error {
|
||||
// Stop the consolidated cleanup routine first
|
||||
@@ -473,7 +520,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
|
||||
// Close all caches first (they won't close the shared backend)
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
@@ -494,35 +541,6 @@ func (m *UniversalCacheManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeCacheManagerFromConfig initializes the cache manager with configuration
|
||||
// This should be called early in the application startup with the loaded configuration
|
||||
func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides only when
|
||||
// the config field is not explicitly set through Traefik
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
// Log cache backend selection
|
||||
if config.Redis != nil && config.Redis.Enabled {
|
||||
logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
|
||||
config.Redis.CacheMode, config.Redis.Address)
|
||||
} else {
|
||||
logger.Info("Initializing cache backend with memory-only mode")
|
||||
}
|
||||
|
||||
return GetUniversalCacheManagerWithConfig(logger, config.Redis)
|
||||
}
|
||||
|
||||
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
|
||||
// This should only be called in test code to ensure proper cleanup between tests
|
||||
func ResetUniversalCacheManagerForTesting() {
|
||||
|
||||
Reference in New Issue
Block a user