mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 06b219d1f8 | |||
| 413e4a1b7d | |||
| 69e0d98c67 |
+231
@@ -1630,3 +1630,234 @@ configuration:
|
||||
|
||||
Default: 30 seconds
|
||||
required: false
|
||||
|
||||
dynamicClientRegistration:
|
||||
type: object
|
||||
description: |
|
||||
Configuration for OIDC Dynamic Client Registration (RFC 7591/7592).
|
||||
|
||||
Dynamic Client Registration allows the middleware to automatically register
|
||||
itself as an OAuth 2.0 client with the OIDC provider, eliminating the need
|
||||
to manually create and manage client credentials.
|
||||
|
||||
This is particularly useful for:
|
||||
- Automated deployments where manual client creation is impractical
|
||||
- Multi-tenant scenarios requiring per-deployment client isolation
|
||||
- Development and testing environments
|
||||
- Kubernetes environments with multiple replicas
|
||||
|
||||
For multi-replica deployments (Kubernetes), enable Redis storage to share
|
||||
credentials across all instances and prevent registration race conditions.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: "redis" # Use Redis for distributed storage
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- https://app.example.com/oauth2/callback
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
```
|
||||
required: false
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable dynamic client registration with the OIDC provider.
|
||||
When enabled and clientID is not set, the middleware will automatically
|
||||
register itself with the provider using the configuration in clientMetadata.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
persistCredentials:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable persistence of client credentials after registration.
|
||||
When enabled, credentials are saved to the configured storage backend
|
||||
and reloaded on restart to avoid re-registration.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
storageBackend:
|
||||
type: string
|
||||
description: |
|
||||
Storage backend for persisting DCR credentials.
|
||||
|
||||
Options:
|
||||
- "file": Store credentials in a local file (default for backward compatibility)
|
||||
- "redis": Store credentials in Redis (recommended for multi-replica deployments)
|
||||
- "auto": Use Redis if available, fall back to file storage
|
||||
|
||||
For Kubernetes deployments with multiple replicas, use "redis" to ensure
|
||||
all instances share the same client credentials and prevent registration
|
||||
race conditions where each replica registers its own client.
|
||||
|
||||
Default: "auto"
|
||||
required: false
|
||||
enum:
|
||||
- file
|
||||
- redis
|
||||
- auto
|
||||
|
||||
credentialsFile:
|
||||
type: string
|
||||
description: |
|
||||
Path to store client credentials when using file-based storage.
|
||||
The file will be created with restrictive permissions (0600).
|
||||
|
||||
Default: "/tmp/oidc-client-credentials.json"
|
||||
required: false
|
||||
|
||||
redisKeyPrefix:
|
||||
type: string
|
||||
description: |
|
||||
Prefix for Redis keys when using Redis storage.
|
||||
Useful for isolating credentials between different applications
|
||||
or environments sharing the same Redis instance.
|
||||
|
||||
Default: "dcr:creds:"
|
||||
required: false
|
||||
|
||||
registrationEndpoint:
|
||||
type: string
|
||||
description: |
|
||||
Override the registration endpoint URL.
|
||||
If not specified, the endpoint will be discovered from provider metadata.
|
||||
|
||||
Some providers may not advertise their registration endpoint in metadata,
|
||||
in which case you need to specify it explicitly.
|
||||
|
||||
Example: "https://auth.example.com/oauth/register"
|
||||
required: false
|
||||
|
||||
initialAccessToken:
|
||||
type: string
|
||||
description: |
|
||||
Initial Access Token for protected registration endpoints.
|
||||
Some providers require an access token to authorize client registration.
|
||||
|
||||
If your provider requires authentication for registration, obtain an
|
||||
initial access token from the provider and configure it here.
|
||||
|
||||
For Kubernetes, you can use secret references:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
required: false
|
||||
|
||||
clientMetadata:
|
||||
type: object
|
||||
description: |
|
||||
Client metadata to include in the registration request (RFC 7591).
|
||||
This defines the properties of the OAuth 2.0 client to be registered.
|
||||
required: false
|
||||
properties:
|
||||
redirect_uris:
|
||||
type: array
|
||||
description: |
|
||||
Array of redirect URIs for the client. Required for registration.
|
||||
These must match the callback URLs that will be used in authentication flows.
|
||||
|
||||
Example: ["https://app.example.com/oauth2/callback"]
|
||||
required: true
|
||||
items:
|
||||
type: string
|
||||
|
||||
client_name:
|
||||
type: string
|
||||
description: |
|
||||
Human-readable name of the client.
|
||||
This is typically displayed in consent screens.
|
||||
|
||||
Example: "My Application"
|
||||
required: false
|
||||
|
||||
application_type:
|
||||
type: string
|
||||
description: |
|
||||
Type of application. Affects security defaults.
|
||||
|
||||
Options:
|
||||
- "web": Server-side web application (default)
|
||||
- "native": Native/mobile application
|
||||
|
||||
Default: "web"
|
||||
required: false
|
||||
|
||||
grant_types:
|
||||
type: array
|
||||
description: |
|
||||
OAuth 2.0 grant types the client will use.
|
||||
|
||||
Default: ["authorization_code", "refresh_token"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
response_types:
|
||||
type: array
|
||||
description: |
|
||||
OAuth 2.0 response types the client will use.
|
||||
|
||||
Default: ["code"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
token_endpoint_auth_method:
|
||||
type: string
|
||||
description: |
|
||||
Authentication method for the token endpoint.
|
||||
|
||||
Options:
|
||||
- "client_secret_basic": HTTP Basic authentication (default)
|
||||
- "client_secret_post": Client credentials in POST body
|
||||
- "none": Public client (no authentication)
|
||||
|
||||
Default: "client_secret_basic"
|
||||
required: false
|
||||
|
||||
scope:
|
||||
type: string
|
||||
description: |
|
||||
Space-separated list of scopes the client is authorized to request.
|
||||
|
||||
Example: "openid profile email"
|
||||
required: false
|
||||
|
||||
contacts:
|
||||
type: array
|
||||
description: |
|
||||
Array of contact email addresses for the client administrator.
|
||||
|
||||
Example: ["admin@example.com"]
|
||||
required: false
|
||||
items:
|
||||
type: string
|
||||
|
||||
logo_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's logo image for consent screens.
|
||||
required: false
|
||||
|
||||
client_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's home page.
|
||||
required: false
|
||||
|
||||
policy_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's privacy policy.
|
||||
required: false
|
||||
|
||||
tos_uri:
|
||||
type: string
|
||||
description: |
|
||||
URL to the client's terms of service.
|
||||
required: false
|
||||
|
||||
@@ -8,7 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
|
||||
|
||||
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
|
||||
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
|
||||
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration
|
||||
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration, with Redis storage support for multi-replica deployments
|
||||
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
|
||||
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
|
||||
- **Domain restrictions**: Limit access to specific email domains or individual users
|
||||
|
||||
+13
-6
@@ -61,7 +61,7 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenCache returns the shared token cache
|
||||
@@ -93,7 +93,7 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
|
||||
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenTypeCache returns the shared token type cache
|
||||
@@ -101,7 +101,7 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
@@ -121,7 +121,8 @@ func CleanupGlobalCacheManager() error {
|
||||
|
||||
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
|
||||
type CacheInterfaceWrapper struct {
|
||||
cache *UniversalCache
|
||||
cache *UniversalCache
|
||||
managed bool // If true, cache is managed globally and Close() is a no-op
|
||||
}
|
||||
|
||||
// Set stores a value
|
||||
@@ -149,9 +150,15 @@ func (c *CacheInterfaceWrapper) Cleanup() {
|
||||
c.cache.Cleanup()
|
||||
}
|
||||
|
||||
// Close shuts down the cache
|
||||
// Close shuts down the cache if it's not managed globally.
|
||||
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
|
||||
// when multiple plugin instances are closed during Traefik configuration reloads.
|
||||
func (c *CacheInterfaceWrapper) Close() {
|
||||
// Close the underlying cache to stop goroutines
|
||||
if c.managed {
|
||||
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
|
||||
return
|
||||
}
|
||||
// Standalone cache - close it properly to stop cleanup goroutines
|
||||
if c.cache != nil {
|
||||
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
|
||||
}
|
||||
|
||||
+153
@@ -219,6 +219,159 @@ func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
nilWrapper.Close()
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_ManagedClose_Regression tests that managed cache wrappers
|
||||
// don't close the underlying cache when Close() is called. This is a regression test
|
||||
// for issue #105 where multiple plugin instances closing shared caches caused log flooding.
|
||||
func TestCacheInterfaceWrapper_ManagedClose_Regression(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Get a managed cache wrapper
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Verify it's marked as managed
|
||||
if !wrapper.managed {
|
||||
t.Error("Expected shared cache wrapper to be marked as managed")
|
||||
}
|
||||
|
||||
// Set some data before Close
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Close the wrapper (should be a no-op for managed caches)
|
||||
wrapper.Close()
|
||||
|
||||
// Verify the cache is still operational after Close
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected cache to still work after Close() on managed wrapper")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Can still set new values
|
||||
cache.Set("new-key", "new-value", time.Hour)
|
||||
newValue, found := cache.Get("new-key")
|
||||
if !found || newValue != "new-value" {
|
||||
t.Error("Expected to be able to set new values after Close() on managed wrapper")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_StandaloneClose tests that standalone cache wrappers
|
||||
// properly close the underlying cache when Close() is called.
|
||||
func TestCacheInterfaceWrapper_StandaloneClose(t *testing.T) {
|
||||
// Create a standalone cache (not from the global cache manager)
|
||||
standaloneCache := NewCache()
|
||||
|
||||
wrapper, ok := standaloneCache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Verify it's NOT marked as managed
|
||||
if wrapper.managed {
|
||||
t.Error("Expected standalone cache wrapper to NOT be marked as managed")
|
||||
}
|
||||
|
||||
// Set some data
|
||||
standaloneCache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Get baseline goroutine count
|
||||
baselineGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Close the wrapper (should actually close the underlying cache)
|
||||
wrapper.Close()
|
||||
|
||||
// Give cleanup goroutine time to stop
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Goroutine count should decrease (cleanup routine stopped)
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > baselineGoroutines {
|
||||
// This is acceptable - other tests might have started goroutines
|
||||
t.Logf("Goroutine count: baseline=%d, final=%d", baselineGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_MultipleInstancesClose_Regression tests that multiple
|
||||
// plugin instances can close their cache wrappers without affecting shared caches.
|
||||
// This is a regression test for issue #105.
|
||||
func TestCacheInterfaceWrapper_MultipleInstancesClose_Regression(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Simulate multiple plugin instances getting cache references
|
||||
instances := make([]*CacheInterfaceWrapper, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
instances[i] = wrapper
|
||||
|
||||
// Each instance might set some data
|
||||
cache.Set(fmt.Sprintf("instance-%d-key", i), fmt.Sprintf("value-%d", i), time.Hour)
|
||||
}
|
||||
|
||||
// Close all instances (simulating plugin shutdown/reload)
|
||||
for _, wrapper := range instances {
|
||||
wrapper.Close()
|
||||
}
|
||||
|
||||
// The shared cache should still work after all instances closed their wrappers
|
||||
newCache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Data set by earlier instances should still be accessible
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("instance-%d-key", i)
|
||||
value, found := newCache.Get(key)
|
||||
if !found {
|
||||
t.Errorf("Expected data from instance %d to still be accessible", i)
|
||||
}
|
||||
expectedValue := fmt.Sprintf("value-%d", i)
|
||||
if value != expectedValue {
|
||||
t.Errorf("Expected '%s', got '%v'", expectedValue, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be able to add new data
|
||||
newCache.Set("after-close-key", "after-close-value", time.Hour)
|
||||
value, found := newCache.Get("after-close-key")
|
||||
if !found || value != "after-close-value" {
|
||||
t.Error("Expected to be able to use cache after all wrapper Close() calls")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllSharedCachesMarkedAsManaged verifies all shared cache getters
|
||||
// return managed wrappers to prevent the log flooding issue.
|
||||
func TestAllSharedCachesMarkedAsManaged(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache CacheInterface
|
||||
}{
|
||||
{"TokenBlacklist", cm.GetSharedTokenBlacklist()},
|
||||
{"IntrospectionCache", cm.GetSharedIntrospectionCache()},
|
||||
{"TokenTypeCache", cm.GetSharedTokenTypeCache()},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wrapper, ok := tt.cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatalf("Expected CacheInterfaceWrapper for %s", tt.name)
|
||||
}
|
||||
if !wrapper.managed {
|
||||
t.Errorf("%s cache wrapper should be marked as managed", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
@@ -0,0 +1,290 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
|
||||
)
|
||||
|
||||
// DCRStorageBackend represents the type of storage backend for DCR credentials.
|
||||
// Alias for internal package type for backward compatibility.
|
||||
type DCRStorageBackend = dcrstorage.StorageBackend
|
||||
|
||||
const (
|
||||
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
|
||||
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
|
||||
|
||||
// DCRStorageBackendRedis uses Redis for distributed storage
|
||||
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
|
||||
|
||||
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
|
||||
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
|
||||
)
|
||||
|
||||
// DCRCredentialsStore defines the interface for storing DCR credentials.
|
||||
// This abstraction allows different storage backends (file, Redis) to be used
|
||||
// for persisting OIDC Dynamic Client Registration credentials across nodes.
|
||||
type DCRCredentialsStore interface {
|
||||
// Save stores the client registration response for a provider
|
||||
// The providerURL is used as a key to support multi-tenant scenarios
|
||||
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
|
||||
|
||||
// Load retrieves stored credentials for a provider
|
||||
// Returns nil, nil if no credentials exist (not an error)
|
||||
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
|
||||
|
||||
// Delete removes stored credentials for a provider
|
||||
Delete(ctx context.Context, providerURL string) error
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
Exists(ctx context.Context, providerURL string) (bool, error)
|
||||
}
|
||||
|
||||
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
|
||||
type loggerAdapter struct {
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
|
||||
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
|
||||
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
|
||||
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
|
||||
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
|
||||
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
|
||||
|
||||
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
|
||||
type cacheAdapter struct {
|
||||
cache *UniversalCache
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Get(key string) (any, bool) {
|
||||
return c.cache.Get(key)
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
|
||||
return c.cache.Set(key, value, ttl)
|
||||
}
|
||||
|
||||
func (c *cacheAdapter) Delete(key string) {
|
||||
c.cache.Delete(key)
|
||||
}
|
||||
|
||||
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
|
||||
type fileStoreWrapper struct {
|
||||
inner *dcrstorage.FileStore
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
innerCreds := convertCredsToInternal(creds)
|
||||
return w.inner.Save(ctx, providerURL, innerCreds)
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
innerCreds, err := w.inner.Load(ctx, providerURL)
|
||||
if err != nil || innerCreds == nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertCredsFromInternal(innerCreds), nil
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
|
||||
return w.inner.Delete(ctx, providerURL)
|
||||
}
|
||||
|
||||
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
return w.inner.Exists(ctx, providerURL)
|
||||
}
|
||||
|
||||
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
|
||||
func (w *fileStoreWrapper) basePath() string {
|
||||
return w.inner.BasePath()
|
||||
}
|
||||
|
||||
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
|
||||
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
|
||||
return w.inner.GetFilePath(providerURL)
|
||||
}
|
||||
|
||||
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
|
||||
type redisStoreWrapper struct {
|
||||
inner *dcrstorage.RedisStore
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
innerCreds := convertCredsToInternal(creds)
|
||||
return w.inner.Save(ctx, providerURL, innerCreds)
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
innerCreds, err := w.inner.Load(ctx, providerURL)
|
||||
if err != nil || innerCreds == nil {
|
||||
return nil, err
|
||||
}
|
||||
return convertCredsFromInternal(innerCreds), nil
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
|
||||
return w.inner.Delete(ctx, providerURL)
|
||||
}
|
||||
|
||||
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
return w.inner.Exists(ctx, providerURL)
|
||||
}
|
||||
|
||||
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
|
||||
// This is the default storage backend for backward compatibility with existing deployments.
|
||||
type FileCredentialsStore = fileStoreWrapper
|
||||
|
||||
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
|
||||
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
|
||||
type RedisCredentialsStore = redisStoreWrapper
|
||||
|
||||
// NewFileCredentialsStore creates a new file-based credentials store.
|
||||
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
|
||||
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
|
||||
var dcrLogger dcrstorage.Logger
|
||||
if logger != nil {
|
||||
dcrLogger = &loggerAdapter{logger: logger}
|
||||
}
|
||||
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
|
||||
return &fileStoreWrapper{inner: inner}
|
||||
}
|
||||
|
||||
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
|
||||
// The cache should be configured with a Redis backend for distributed storage.
|
||||
// If keyPrefix is empty, defaults to "dcr:creds:"
|
||||
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
|
||||
var dcrLogger dcrstorage.Logger
|
||||
if logger != nil {
|
||||
dcrLogger = &loggerAdapter{logger: logger}
|
||||
}
|
||||
cacheAdapt := &cacheAdapter{cache: cache}
|
||||
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
|
||||
return &redisStoreWrapper{inner: inner}
|
||||
}
|
||||
|
||||
// Helper functions to convert between main package and internal package types
|
||||
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
|
||||
if creds == nil {
|
||||
return nil
|
||||
}
|
||||
return &dcrstorage.ClientRegistrationResponse{
|
||||
SubjectType: creds.SubjectType,
|
||||
LogoURI: creds.LogoURI,
|
||||
RegistrationAccessToken: creds.RegistrationAccessToken,
|
||||
RegistrationClientURI: creds.RegistrationClientURI,
|
||||
Scope: creds.Scope,
|
||||
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
|
||||
TOSURI: creds.TOSURI,
|
||||
PolicyURI: creds.PolicyURI,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
ApplicationType: creds.ApplicationType,
|
||||
ClientID: creds.ClientID,
|
||||
ClientName: creds.ClientName,
|
||||
JWKSURI: creds.JWKSURI,
|
||||
ClientURI: creds.ClientURI,
|
||||
Contacts: creds.Contacts,
|
||||
GrantTypes: creds.GrantTypes,
|
||||
ResponseTypes: creds.ResponseTypes,
|
||||
RedirectURIs: creds.RedirectURIs,
|
||||
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
|
||||
ClientIDIssuedAt: creds.ClientIDIssuedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
|
||||
if creds == nil {
|
||||
return nil
|
||||
}
|
||||
return &ClientRegistrationResponse{
|
||||
SubjectType: creds.SubjectType,
|
||||
LogoURI: creds.LogoURI,
|
||||
RegistrationAccessToken: creds.RegistrationAccessToken,
|
||||
RegistrationClientURI: creds.RegistrationClientURI,
|
||||
Scope: creds.Scope,
|
||||
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
|
||||
TOSURI: creds.TOSURI,
|
||||
PolicyURI: creds.PolicyURI,
|
||||
ClientSecret: creds.ClientSecret,
|
||||
ApplicationType: creds.ApplicationType,
|
||||
ClientID: creds.ClientID,
|
||||
ClientName: creds.ClientName,
|
||||
JWKSURI: creds.JWKSURI,
|
||||
ClientURI: creds.ClientURI,
|
||||
Contacts: creds.Contacts,
|
||||
GrantTypes: creds.GrantTypes,
|
||||
ResponseTypes: creds.ResponseTypes,
|
||||
RedirectURIs: creds.RedirectURIs,
|
||||
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
|
||||
ClientIDIssuedAt: creds.ClientIDIssuedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
|
||||
// This factory function handles backend selection logic:
|
||||
// - "file": Use file-based storage (default for backward compatibility)
|
||||
// - "redis": Use Redis exclusively (fails if Redis unavailable)
|
||||
// - "auto": Use Redis if available, fallback to file
|
||||
func NewDCRCredentialsStore(
|
||||
config *DynamicClientRegistrationConfig,
|
||||
cacheManager *CacheManager,
|
||||
logger *Logger,
|
||||
) (DCRCredentialsStore, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("DCR config is nil")
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
backend := config.StorageBackend
|
||||
if backend == "" {
|
||||
backend = string(DCRStorageBackendAuto) // Default to auto selection
|
||||
}
|
||||
|
||||
switch DCRStorageBackend(backend) {
|
||||
case DCRStorageBackendFile:
|
||||
logger.Info("Using file-based storage for DCR credentials")
|
||||
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
|
||||
|
||||
case DCRStorageBackendRedis:
|
||||
cache := getDCRCache(cacheManager)
|
||||
if cache == nil {
|
||||
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
|
||||
}
|
||||
logger.Info("Using Redis storage for DCR credentials")
|
||||
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
|
||||
|
||||
case DCRStorageBackendAuto:
|
||||
// Try Redis first, fallback to file
|
||||
cache := getDCRCache(cacheManager)
|
||||
if cache != nil && cache.backend != nil {
|
||||
logger.Info("Auto-selected Redis storage for DCR credentials")
|
||||
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
|
||||
}
|
||||
logger.Info("Redis not available, using file storage for DCR credentials")
|
||||
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
|
||||
}
|
||||
}
|
||||
|
||||
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
|
||||
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
|
||||
if cacheManager == nil {
|
||||
return nil
|
||||
}
|
||||
cacheManager.mu.RLock()
|
||||
defer cacheManager.mu.RUnlock()
|
||||
|
||||
if cacheManager.manager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cacheManager.manager.GetDCRCredentialsCache()
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
|
||||
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a temp directory for test files
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "test-access-token",
|
||||
RegistrationClientURI: "https://example.com/register/test-client-id",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
ResponseTypes: []string{"code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
// Save credentials
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
// Load credentials
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
|
||||
// Verify fields
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
|
||||
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
tempDir2 := t.TempDir()
|
||||
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
|
||||
|
||||
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent file: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
|
||||
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Expected credentials to not exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existent credentials", func(t *testing.T) {
|
||||
// Should not error
|
||||
err := store.Delete(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_MultiProvider tests multi-provider support
|
||||
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
provider1 := "https://auth1.example.com"
|
||||
provider2 := "https://auth2.example.com"
|
||||
|
||||
creds1 := &ClientRegistrationResponse{
|
||||
ClientID: "client-1",
|
||||
ClientSecret: "secret-1",
|
||||
}
|
||||
creds2 := &ClientRegistrationResponse{
|
||||
ClientID: "client-2",
|
||||
ClientSecret: "secret-2",
|
||||
}
|
||||
|
||||
// Save credentials for both providers
|
||||
if err := store.Save(ctx, provider1, creds1); err != nil {
|
||||
t.Fatalf("Failed to save creds1: %v", err)
|
||||
}
|
||||
if err := store.Save(ctx, provider2, creds2); err != nil {
|
||||
t.Fatalf("Failed to save creds2: %v", err)
|
||||
}
|
||||
|
||||
// Load and verify each provider's credentials
|
||||
loaded1, err := store.Load(ctx, provider1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds1: %v", err)
|
||||
}
|
||||
if loaded1.ClientID != "client-1" {
|
||||
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
|
||||
}
|
||||
|
||||
loaded2, err := store.Load(ctx, provider2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds2: %v", err)
|
||||
}
|
||||
if loaded2.ClientID != "client-2" {
|
||||
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
|
||||
}
|
||||
|
||||
// Delete one shouldn't affect the other
|
||||
if err := store.Delete(ctx, provider1); err != nil {
|
||||
t.Fatalf("Failed to delete creds1: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, provider2)
|
||||
if !exists {
|
||||
t.Error("Provider 2 credentials should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
|
||||
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
|
||||
// Concurrent saves
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = store.Save(ctx, providerURL, creds)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Concurrent loads
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = store.Load(ctx, providerURL)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Final verification
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after concurrent access: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test-client" {
|
||||
t.Error("Credentials corrupted after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_InvalidInput tests error handling
|
||||
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty provider URL uses default path", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
err := store.Save(ctx, "", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with empty provider URL failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with empty provider URL failed: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials with empty provider URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_DefaultPath tests default path behavior
|
||||
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore("", logger)
|
||||
|
||||
// Just verify we can create with empty path and it has a default
|
||||
if store.basePath() == "" {
|
||||
t.Error("Expected default base path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
|
||||
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create an in-memory cache for testing
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "redis-test-client",
|
||||
ClientSecret: "redis-test-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "redis-test-token",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
}
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
|
||||
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("expired credentials should fail", func(t *testing.T) {
|
||||
expiredCreds := &ClientRegistrationResponse{
|
||||
ClientID: "expired-client",
|
||||
ClientSecret: "expired-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "no-expiry-client",
|
||||
ClientSecret: "no-expiry-secret",
|
||||
ClientSecretExpiresAt: 0, // No expiry
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://noexpiry.example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials without expiry: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisCredentialsStore_InvalidInput tests error handling
|
||||
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: time.Hour,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewRedisCredentialsStore(cache, "", logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDCRStorageFactory tests the factory function
|
||||
func TestDCRStorageFactory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("nil config returns error", func(t *testing.T) {
|
||||
_, err := NewDCRCredentialsStore(nil, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("file backend creates file store", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "file",
|
||||
CredentialsFile: "/tmp/test-creds.json",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create file store: %v", err)
|
||||
}
|
||||
if store == nil {
|
||||
t.Error("Expected store but got nil")
|
||||
}
|
||||
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "redis",
|
||||
}
|
||||
|
||||
_, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for redis backend without cache manager")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "auto",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create auto store: %v", err)
|
||||
}
|
||||
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore for auto without redis")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown backend returns error", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "unknown",
|
||||
}
|
||||
|
||||
_, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err == nil {
|
||||
t.Error("Expected error for unknown backend")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty backend defaults to auto", func(t *testing.T) {
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
StorageBackend: "",
|
||||
}
|
||||
|
||||
store, err := NewDCRCredentialsStore(config, nil, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create store with empty backend: %v", err)
|
||||
}
|
||||
|
||||
// Should default to file (auto without redis)
|
||||
_, ok := store.(*FileCredentialsStore)
|
||||
if !ok {
|
||||
t.Error("Expected FileCredentialsStore for empty backend")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDynamicClientRegistrar_WithStore tests registrar with store
|
||||
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrarWithStore(
|
||||
nil, // httpClient
|
||||
logger,
|
||||
config,
|
||||
"https://auth.example.com",
|
||||
store,
|
||||
)
|
||||
|
||||
if registrar == nil {
|
||||
t.Fatal("Expected registrar but got nil")
|
||||
}
|
||||
|
||||
if registrar.store == nil {
|
||||
t.Error("Expected store to be set")
|
||||
}
|
||||
|
||||
// Test SetStore
|
||||
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
|
||||
registrar.SetStore(newStore)
|
||||
|
||||
if registrar.store != newStore {
|
||||
t.Error("SetStore did not update the store")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
|
||||
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
providerURL := "https://auth.example.com"
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-save credentials
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "pre-saved-client",
|
||||
ClientSecret: "pre-saved-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
}
|
||||
if err := store.Save(ctx, providerURL, testCreds); err != nil {
|
||||
t.Fatalf("Failed to pre-save credentials: %v", err)
|
||||
}
|
||||
|
||||
config := &DynamicClientRegistrationConfig{
|
||||
Enabled: true,
|
||||
PersistCredentials: true,
|
||||
}
|
||||
|
||||
registrar := NewDynamicClientRegistrarWithStore(
|
||||
nil,
|
||||
logger,
|
||||
config,
|
||||
providerURL,
|
||||
store,
|
||||
)
|
||||
|
||||
// Test loading via the internal method
|
||||
loaded, err := registrar.loadCredentialsFromStore(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load from store: %v", err)
|
||||
}
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != "pre-saved-client" {
|
||||
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
|
||||
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(basePath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
// Write corrupted JSON
|
||||
filePath := store.getFilePath(providerURL)
|
||||
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
|
||||
t.Fatalf("Failed to write corrupted file: %v", err)
|
||||
}
|
||||
|
||||
// Should return error for corrupted file
|
||||
_, err := store.Load(ctx, providerURL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for corrupted JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
|
||||
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
|
||||
logger := GetSingletonNoOpLogger()
|
||||
store := NewFileCredentialsStore(deepPath, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
|
||||
err := store.Save(ctx, "https://example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save with nested directory: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after nested directory creation: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials from nested directory")
|
||||
}
|
||||
}
|
||||
+34
-1
@@ -384,10 +384,14 @@ scopes:
|
||||
|
||||
### Dynamic Client Registration (RFC 7591)
|
||||
|
||||
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
|
||||
|
||||
**Basic Configuration (Single Instance):**
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
initialAccessToken: "your-token" # Optional
|
||||
initialAccessToken: "your-token" # Optional, if provider requires it
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-credentials.json"
|
||||
clientMetadata:
|
||||
@@ -400,6 +404,35 @@ dynamicClientRegistration:
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
**Multi-Replica Deployment (Kubernetes):**
|
||||
|
||||
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
persistCredentials: true
|
||||
storageBackend: "redis" # Share credentials via Redis
|
||||
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "redis"
|
||||
```
|
||||
|
||||
**Storage Backend Options:**
|
||||
|
||||
| Backend | Description | Use Case |
|
||||
|---------|-------------|----------|
|
||||
| `file` | Store credentials in local file | Single instance deployments |
|
||||
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
|
||||
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
|
||||
|
||||
### Multi-Replica Deployment
|
||||
|
||||
Without Redis, disable replay detection:
|
||||
|
||||
@@ -353,6 +353,8 @@ allowPrivateIPAddresses: true # Required for private IPs
|
||||
- Roles: User Client Role mapper with "Add to ID token" enabled
|
||||
- Groups: Group Membership mapper with "Add to ID token" enabled
|
||||
|
||||
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
+43
-1
@@ -193,7 +193,7 @@
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-1">Dynamic Registration</h3>
|
||||
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration</p>
|
||||
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration with Redis storage support for multi-replica deployments</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -862,6 +862,48 @@ spec:
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Dynamic Client Registration (RFC 7591)</h3>
|
||||
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Automatically register your application with the OIDC provider. Supports Redis storage for multi-replica deployments:</p>
|
||||
<div class="overflow-x-auto mb-4">
|
||||
<table class="w-full text-sm">
|
||||
<thead>
|
||||
<tr class="border-b border-gray-200 dark:border-gray-700">
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Parameter</th>
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Default</th>
|
||||
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Description</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody class="text-gray-600 dark:text-gray-400">
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.enabled</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable dynamic client registration</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.persistCredentials</code></td>
|
||||
<td class="py-2 px-3">true</td>
|
||||
<td class="py-2 px-3">Persist registered credentials across restarts</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.storageBackend</code></td>
|
||||
<td class="py-2 px-3">auto</td>
|
||||
<td class="py-2 px-3">Storage backend: <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">file</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis</code>, or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">auto</code> (uses Redis if available)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.redisKeyPrefix</code></td>
|
||||
<td class="py-2 px-3">dcr:creds:</td>
|
||||
<td class="py-2 px-3">Redis key prefix for DCR credentials</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.clientMetadata.redirect_uris</code></td>
|
||||
<td class="py-2 px-3">-</td>
|
||||
<td class="py-2 px-3">Redirect URIs for the registered client (required)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div class="glass p-6 rounded-xl">
|
||||
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Security Headers with CORS</h3>
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ type DynamicClientRegistrar struct {
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
store DCRCredentialsStore // Storage backend for credentials
|
||||
providerURL string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -73,8 +74,37 @@ func NewDynamicClientRegistrar(
|
||||
}
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
|
||||
func NewDynamicClientRegistrarWithStore(
|
||||
httpClient *http.Client,
|
||||
logger *Logger,
|
||||
dcrConfig *DynamicClientRegistrationConfig,
|
||||
providerURL string,
|
||||
store DCRCredentialsStore,
|
||||
) *DynamicClientRegistrar {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
return &DynamicClientRegistrar{
|
||||
httpClient: httpClient,
|
||||
logger: logger,
|
||||
config: dcrConfig,
|
||||
providerURL: providerURL,
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// SetStore sets the credentials store for the registrar
|
||||
// This allows setting the store after creation when the cache manager is available
|
||||
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.store = store
|
||||
}
|
||||
|
||||
// RegisterClient performs dynamic client registration with the OIDC provider
|
||||
// It first attempts to load existing credentials from a file if persistence is enabled,
|
||||
// It first attempts to load existing credentials from storage if persistence is enabled,
|
||||
// then registers a new client if no valid credentials exist.
|
||||
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
|
||||
if r.config == nil || !r.config.Enabled {
|
||||
@@ -83,10 +113,13 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
|
||||
|
||||
// Try to load existing credentials if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if resp, err := r.loadCredentials(); err == nil && resp != nil {
|
||||
resp, err := r.loadCredentialsFromStore(ctx)
|
||||
if err != nil {
|
||||
r.logger.Debugf("Failed to load credentials from store: %v", err)
|
||||
} else if resp != nil {
|
||||
// Check if credentials are still valid (not expired)
|
||||
if r.areCredentialsValid(resp) {
|
||||
r.logger.Info("Loaded existing client credentials from file")
|
||||
r.logger.Info("Loaded existing client credentials from storage")
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = resp
|
||||
r.mu.Unlock()
|
||||
@@ -179,7 +212,7 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
|
||||
|
||||
// Persist credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist client credentials: %v", err)
|
||||
// Don't fail registration if persistence fails
|
||||
}
|
||||
@@ -315,7 +348,44 @@ func (r *DynamicClientRegistrar) credentialsFilePath() string {
|
||||
return "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file
|
||||
// loadCredentialsFromStore loads client credentials from the configured storage backend
|
||||
// Falls back to legacy file-based loading if no store is configured
|
||||
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Load(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based loading
|
||||
return r.loadCredentials()
|
||||
}
|
||||
|
||||
// saveCredentialsToStore persists client credentials to the configured storage backend
|
||||
// Falls back to legacy file-based saving if no store is configured
|
||||
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Save(ctx, r.providerURL, resp)
|
||||
}
|
||||
// Fallback to legacy file-based saving
|
||||
return r.saveCredentials(resp)
|
||||
}
|
||||
|
||||
// deleteCredentialsFromStore removes credentials from the configured storage backend
|
||||
// Falls back to legacy file-based deletion if no store is configured
|
||||
func (r *DynamicClientRegistrar) deleteCredentialsFromStore(ctx context.Context) error {
|
||||
// Use store if available
|
||||
if r.store != nil {
|
||||
return r.store.Delete(ctx, r.providerURL)
|
||||
}
|
||||
// Fallback to legacy file-based deletion
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
@@ -333,7 +403,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials loads client credentials from a file
|
||||
// loadCredentials loads client credentials from a file (legacy method)
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
@@ -420,7 +490,7 @@ func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
if err := r.saveCredentialsToStore(ctx, ®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -527,11 +597,10 @@ func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) e
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials file if persistence is enabled
|
||||
// Remove credentials from storage if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
r.logger.Errorf("Failed to remove credentials file: %v", err)
|
||||
if err := r.deleteCredentialsFromStore(ctx); err != nil {
|
||||
r.logger.Errorf("Failed to remove credentials from storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+224
-197
@@ -2,20 +2,27 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Default configuration values
|
||||
const (
|
||||
defaultShardCount = 256
|
||||
defaultMaxSize = int64(10000)
|
||||
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
|
||||
defaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
value interface{}
|
||||
element *list.Element
|
||||
element interface{} // *list.Element, using interface{} to avoid import cycle
|
||||
key string
|
||||
accessCount int64
|
||||
size int64
|
||||
@@ -29,56 +36,89 @@ func (item *memoryCacheItem) isExpired() bool {
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
|
||||
// The sharded design reduces lock contention by partitioning keys across multiple shards,
|
||||
// each with its own lock.
|
||||
type MemoryCacheBackend struct {
|
||||
shards []*cacheShard
|
||||
startTime time.Time
|
||||
lastErrorTime time.Time
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
cleanupDone chan bool
|
||||
cleanupDone chan struct{}
|
||||
cleanupTicker *time.Ticker
|
||||
evictionPolicy string
|
||||
lastError string
|
||||
currentMemory int64
|
||||
misses atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
sets atomic.Int64
|
||||
hits atomic.Int64
|
||||
shardCount uint32
|
||||
shardMask uint32
|
||||
maxSize int64
|
||||
currentSize int64
|
||||
maxMemory int64
|
||||
cleanupInterval time.Duration
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Global stats (aggregated from shards)
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// State
|
||||
closed atomic.Bool
|
||||
mu sync.RWMutex // For global operations like stats and error tracking
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
// NewMemoryCacheBackend creates a new sharded memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
maxSize = defaultMaxSize
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
maxMemory = defaultMaxMemory
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
cleanupInterval = defaultCleanupInterval
|
||||
}
|
||||
|
||||
shardCount := uint32(defaultShardCount)
|
||||
|
||||
// For very small caches, reduce shard count to maintain sensible per-shard limits
|
||||
// Ensure each shard can hold at least 2 items for proper LRU behavior
|
||||
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
|
||||
shardCount /= 2
|
||||
}
|
||||
if shardCount < 1 {
|
||||
shardCount = 1
|
||||
}
|
||||
|
||||
// Per-shard limits are soft hints; global limits are enforced
|
||||
// Give shards 2x the average to allow for uneven distribution
|
||||
shardMaxSize := (maxSize * 2) / int64(shardCount)
|
||||
if shardMaxSize < 4 {
|
||||
shardMaxSize = 4
|
||||
}
|
||||
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
|
||||
if shardMaxMemory < 4096 {
|
||||
shardMaxMemory = 4096 // Minimum 4KB per shard
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
shards: make([]*cacheShard, shardCount),
|
||||
shardCount: shardCount,
|
||||
shardMask: shardCount - 1, // For fast modulo with power-of-2
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize shards
|
||||
for i := uint32(0); i < shardCount; i++ {
|
||||
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
@@ -88,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.
|
||||
return m
|
||||
}
|
||||
|
||||
// getShard returns the shard for a given key
|
||||
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
|
||||
hash := fnv32(key)
|
||||
return m.shards[hash&m.shardMask]
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
@@ -100,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
// cleanupExpired removes all expired items from all shards
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
totalRemoved := 0
|
||||
for _, shard := range m.shards {
|
||||
totalRemoved += shard.cleanup()
|
||||
}
|
||||
|
||||
if totalRemoved > 0 {
|
||||
m.evictions.Add(int64(totalRemoved))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{},
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
shard := m.getShard(key)
|
||||
value, exists, expired := shard.get(key)
|
||||
|
||||
if expired {
|
||||
// Clean up expired item
|
||||
shard.delete(key)
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
@@ -174,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
// Calculate item size
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Enforce global limits before adding new item
|
||||
m.enforceGlobalLimits(itemSize)
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
expiresAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
shard := m.getShard(key)
|
||||
shard.set(key, value, expiresAt, itemSize)
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceGlobalLimits ensures global size and memory limits are respected
|
||||
// by evicting from shards when necessary
|
||||
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
|
||||
// Check and enforce size limit
|
||||
for {
|
||||
totalSize, totalMemory := m.getGlobalStats()
|
||||
|
||||
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
|
||||
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
|
||||
|
||||
if !needsSizeEviction && !needsMemoryEviction {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the shard with the most items and evict from it
|
||||
evicted := m.evictFromLargestShard()
|
||||
if !evicted {
|
||||
break // No more items to evict
|
||||
}
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// getGlobalStats returns the total size and memory usage across all shards
|
||||
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// evictFromLargestShard evicts the globally oldest item across all shards
|
||||
// This provides true LRU behavior even with sharding
|
||||
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
|
||||
var oldestShard *cacheShard
|
||||
var oldestTime time.Time
|
||||
|
||||
for _, shard := range m.shards {
|
||||
accessTime := shard.getOldestAccessTime()
|
||||
// Skip empty shards
|
||||
if accessTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
// Find the shard with the oldest (earliest) access time
|
||||
if oldestShard == nil || accessTime.Before(oldestTime) {
|
||||
oldestTime = accessTime
|
||||
oldestShard = shard
|
||||
}
|
||||
}
|
||||
|
||||
if oldestShard == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return oldestShard.evictOne()
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
shard := m.getShard(key)
|
||||
if shard.delete(key) {
|
||||
m.deletes.Add(1)
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
shard := m.getShard(key)
|
||||
return shard.exists(key), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
@@ -289,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -306,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
var allKeys []string
|
||||
for _, shard := range m.shards {
|
||||
keys := shard.keys(pattern)
|
||||
allKeys = append(allKeys, keys...)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
return allKeys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
// Size returns the total number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
var total int64
|
||||
for _, shard := range m.shards {
|
||||
size, _ := shard.stats()
|
||||
total += size
|
||||
}
|
||||
|
||||
return m.currentSize, nil
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
@@ -337,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
ttl, exists := shard.ttl(key)
|
||||
if !exists {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
@@ -363,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
if !shard.expire(key, ttl) {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -386,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
// Aggregate stats from all shards
|
||||
var totalSize, totalMemory int64
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
@@ -409,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
CurrentSize: totalSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
MemoryUsage: totalMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
@@ -438,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error {
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
// Clear all shards
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -474,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
// GetShardCount returns the number of shards (for testing/monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardCount() uint32 {
|
||||
return m.shardCount
|
||||
}
|
||||
|
||||
// GetShardStats returns per-shard statistics (for monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
|
||||
stats := make([]map[string]int64, m.shardCount)
|
||||
for i, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
stats[i] = map[string]int64{
|
||||
"size": size,
|
||||
"memory": memory,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
@@ -502,7 +526,10 @@ func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
// Simplified pattern matching
|
||||
if len(pattern) > 0 && pattern[0] == '*' {
|
||||
suffix := pattern[1:]
|
||||
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
|
||||
}
|
||||
return key == pattern
|
||||
}
|
||||
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheShard represents a single shard of the sharded cache
|
||||
// Each shard has its own lock for reduced contention
|
||||
type cacheShard struct {
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
mu sync.RWMutex
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
size int64
|
||||
memoryUsed int64
|
||||
}
|
||||
|
||||
// newCacheShard creates a new cache shard
|
||||
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
|
||||
return &cacheShard{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a value from this shard
|
||||
// Returns: value, exists, expired
|
||||
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
return nil, true, true // exists but expired
|
||||
}
|
||||
|
||||
// Update access time and LRU position under write lock
|
||||
s.mu.Lock()
|
||||
// Re-check item exists (could have been deleted)
|
||||
item, exists = s.items[key]
|
||||
if exists && !item.isExpired() {
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.MoveToFront(elem)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
return item.value, true, false
|
||||
}
|
||||
|
||||
// set stores a value in this shard
|
||||
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if s.maxSize > 0 && s.size >= s.maxSize {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
|
||||
// Remove old item if exists
|
||||
if oldItem, exists := s.items[key]; exists {
|
||||
s.memoryUsed -= oldItem.size
|
||||
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
s.size--
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: size,
|
||||
}
|
||||
|
||||
item.element = s.lruList.PushFront(item)
|
||||
s.items[key] = item
|
||||
s.size++
|
||||
s.memoryUsed += size
|
||||
}
|
||||
|
||||
// delete removes a key from this shard
|
||||
// Returns true if the key was deleted
|
||||
func (s *cacheShard) delete(key string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
|
||||
// exists checks if a key exists (and is not expired)
|
||||
func (s *cacheShard) exists(key string) bool {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return !item.isExpired()
|
||||
}
|
||||
|
||||
// ttl returns the remaining TTL for a key
|
||||
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, true // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return remaining, true
|
||||
}
|
||||
|
||||
// expire updates the TTL for an existing key
|
||||
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return false
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// keys returns all non-expired keys matching the pattern
|
||||
func (s *cacheShard) keys(pattern string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range s.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// clear removes all items from this shard
|
||||
func (s *cacheShard) clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.items = make(map[string]*memoryCacheItem)
|
||||
s.lruList.Init()
|
||||
s.size = 0
|
||||
s.memoryUsed = 0
|
||||
}
|
||||
|
||||
// cleanup removes expired items
|
||||
// Returns the number of items removed
|
||||
func (s *cacheShard) cleanup() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var toRemove []*memoryCacheItem
|
||||
for _, item := range s.items {
|
||||
if item.isExpired() {
|
||||
toRemove = append(toRemove, item)
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range toRemove {
|
||||
s.deleteItemLocked(item)
|
||||
}
|
||||
|
||||
return len(toRemove)
|
||||
}
|
||||
|
||||
// stats returns statistics for this shard
|
||||
func (s *cacheShard) stats() (size, memory int64) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.size, s.memoryUsed
|
||||
}
|
||||
|
||||
// deleteItemLocked removes an item (must be called with lock held)
|
||||
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
delete(s.items, item.key)
|
||||
s.size--
|
||||
s.memoryUsed -= item.size
|
||||
}
|
||||
|
||||
// evictLRULocked evicts the least recently used item (must be called with lock held)
|
||||
func (s *cacheShard) evictLRULocked() bool {
|
||||
if s.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// evictOne evicts one item from this shard (for global limit enforcement)
|
||||
func (s *cacheShard) evictOne() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.evictLRULocked()
|
||||
}
|
||||
|
||||
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
|
||||
// Returns zero time if shard is empty
|
||||
func (s *cacheShard) getOldestAccessTime() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.lruList.Len() == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
return item.accessedAt
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a hash of a string
|
||||
// This is a fast, well-distributed hash function
|
||||
func fnv32(key string) uint32 {
|
||||
const (
|
||||
offset32 = uint32(2166136261)
|
||||
prime32 = uint32(16777619)
|
||||
)
|
||||
|
||||
hash := offset32
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash ^= uint32(key[i])
|
||||
hash *= prime32
|
||||
}
|
||||
return hash
|
||||
}
|
||||
+283
@@ -0,0 +1,283 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
|
||||
func TestShardedCache_ShardDistribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a cache with large enough size to have multiple shards
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add many items to see distribution
|
||||
numItems := 1000
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("dist-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("dist-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Check that items are distributed across multiple shards
|
||||
shardStats := backend.MemoryCacheBackend.GetShardStats()
|
||||
nonEmptyShards := 0
|
||||
for _, stat := range shardStats {
|
||||
if stat["size"] > 0 {
|
||||
nonEmptyShards++
|
||||
}
|
||||
}
|
||||
|
||||
// With good hash distribution, we should have items in multiple shards
|
||||
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
|
||||
}
|
||||
|
||||
// TestShardedCache_ShardCount tests that shard count adapts to cache size
|
||||
func TestShardedCache_ShardCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
maxSize int
|
||||
expectLowShards bool
|
||||
}{
|
||||
{5, true}, // Very small cache should have fewer shards
|
||||
{100, true}, // Small cache should have fewer shards
|
||||
{10000, false}, // Large cache should have default shards
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = tt.maxSize
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
shardCount := backend.MemoryCacheBackend.GetShardCount()
|
||||
|
||||
if tt.expectLowShards {
|
||||
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
|
||||
} else {
|
||||
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
|
||||
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "concurrent-same-key"
|
||||
initialValue := []byte("initial-value")
|
||||
|
||||
err = backend.Set(ctx, key, initialValue, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 50
|
||||
iterations := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of reads and writes
|
||||
if j%3 == 0 {
|
||||
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, newValue, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Key should still exist
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
|
||||
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a small cache to force eviction
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
// Small delay to ensure different access times
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Access some items to make them recently used
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Add more items to trigger eviction
|
||||
for i := 10; i < 15; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Recently accessed items (5-9) should still exist
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item %d should exist", i)
|
||||
}
|
||||
|
||||
// Check eviction stats
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
|
||||
func TestShardedCache_StatsAggregation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items to multiple shards
|
||||
numItems := 100
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("stats-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Read some items
|
||||
for i := 0; i < numItems/2; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Read non-existent items
|
||||
for i := 0; i < 10; i++ {
|
||||
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
|
||||
// Verify stats
|
||||
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
|
||||
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
|
||||
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
|
||||
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
|
||||
|
||||
// Verify hit rate
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
|
||||
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_Parallel benchmarks parallel access
|
||||
func BenchmarkShardedCache_Parallel(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10000; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("bench-key-%d", i%10000)
|
||||
backend.Get(ctx, key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
|
||||
func BenchmarkShardedCache_MixedOps(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("mixed-key-%d", i%1000)
|
||||
if i%3 == 0 {
|
||||
value := []byte(fmt.Sprintf("mixed-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
} else {
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
+20
-30
@@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
// Get TTL using the TTL method
|
||||
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
|
||||
if ttlErr != nil {
|
||||
// If we can't get TTL, still return the value with 0 TTL
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
@@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
// If it's not already []byte, return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
@@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"shard_count": m.MemoryCacheBackend.GetShardCount(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+107
-11
@@ -431,39 +431,135 @@ func isRetryableError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
// SetMany stores multiple values in Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For single items, use regular Set
|
||||
if len(items) == 1 {
|
||||
for key, value := range items {
|
||||
return r.Set(ctx, key, value, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all SET commands
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
|
||||
for key, value := range items {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
if ttl > 0 {
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs
|
||||
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs
|
||||
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
pipeline.Queue("SET", prefixedKey, string(value))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline SetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Check responses for errors (each should be "OK")
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
if str, ok := resp.(string); ok && str == "OK" {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
// GetMany retrieves multiple values from Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
// For single key, use regular Get
|
||||
if len(keys) == 1 {
|
||||
result := make(map[string][]byte)
|
||||
value, _, exists, err := r.Get(ctx, keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
result[keys[0]] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all GET commands
|
||||
prefixedKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
prefixedKeys[i] = r.prefixKey(key)
|
||||
pipeline.Queue("GET", prefixedKeys[i])
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Process responses
|
||||
result := make(map[string][]byte)
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
// Key doesn't exist
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
// Invalid response, skip this key
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
result[keys[i]] = []byte(value)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
+461
@@ -0,0 +1,461 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTestRedis creates a miniredis instance for testing
|
||||
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "test:",
|
||||
PoolSize: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
backend.Close()
|
||||
})
|
||||
|
||||
return mr, backend
|
||||
}
|
||||
|
||||
// TestPipeline_Basic tests basic pipeline functionality
|
||||
func TestPipeline_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.Addr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SingleCommand", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "single-key", "single-value")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
})
|
||||
|
||||
t.Run("MultipleCommands", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "key1", "value1")
|
||||
pipeline.Queue("SET", "key2", "value2")
|
||||
pipeline.Queue("SET", "key3", "value3")
|
||||
pipeline.Queue("GET", "key1")
|
||||
pipeline.Queue("GET", "key2")
|
||||
pipeline.Queue("GET", "key3")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 6)
|
||||
|
||||
// First 3 are SET responses
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
assert.Equal(t, "OK", responses[1])
|
||||
assert.Equal(t, "OK", responses[2])
|
||||
|
||||
// Last 3 are GET responses
|
||||
assert.Equal(t, "value1", responses[3])
|
||||
assert.Equal(t, "value2", responses[4])
|
||||
assert.Equal(t, "value3", responses[5])
|
||||
})
|
||||
|
||||
t.Run("EmptyPipeline", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, responses)
|
||||
})
|
||||
|
||||
t.Run("NilResponses", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("GET", "nonexistent-key")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Nil(t, responses[0])
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_SetMany tests pipelined SetMany
|
||||
func TestPipeline_SetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetManyItems", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key %s should exist", key)
|
||||
assert.Equal(t, expectedValue, value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetManyEmpty", func(t *testing.T) {
|
||||
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetManySingleItem", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"single-setmany": []byte("single-value"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
value, _, exists, err := backend.Get(ctx, "single-setmany")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("single-value"), value)
|
||||
})
|
||||
|
||||
t.Run("SetManyNoTTL", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"nottl-key1": []byte("value1"),
|
||||
"nottl-key2": []byte("value2"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Keys should exist
|
||||
for key := range items {
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_GetMany tests pipelined GetMany
|
||||
func TestPipeline_GetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("getmany-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("GetManyExisting", func(t *testing.T) {
|
||||
keys := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
keys[i] = fmt.Sprintf("getmany-key-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 10)
|
||||
|
||||
for i, key := range keys {
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyMixed", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"getmany-key-0", // exists
|
||||
"nonexistent-key-1", // doesn't exist
|
||||
"getmany-key-2", // exists
|
||||
"nonexistent-key-2", // doesn't exist
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
|
||||
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
|
||||
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
|
||||
assert.NotContains(t, results, "nonexistent-key-1")
|
||||
assert.NotContains(t, results, "nonexistent-key-2")
|
||||
})
|
||||
|
||||
t.Run("GetManyEmpty", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
|
||||
t.Run("GetManySingleKey", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
|
||||
})
|
||||
|
||||
t.Run("GetManyAllNonexistent", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"nonexistent-1",
|
||||
"nonexistent-2",
|
||||
"nonexistent-3",
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_LargeBatch tests pipelining with large batches
|
||||
func TestPipeline_LargeBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany100Items", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify random samples
|
||||
for _, i := range []int{0, 25, 50, 75, 99} {
|
||||
key := fmt.Sprintf("large-batch-%d", i)
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany100Items", func(t *testing.T) {
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("large-batch-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 100)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
|
||||
func TestPipeline_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some items
|
||||
items := map[string][]byte{
|
||||
"stats-key-1": []byte("value1"),
|
||||
"stats-key-2": []byte("value2"),
|
||||
}
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get items (some exist, some don't)
|
||||
keys := []string{
|
||||
"stats-key-1",
|
||||
"stats-key-2",
|
||||
"stats-key-nonexistent",
|
||||
}
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
// Check stats
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
|
||||
assert.Equal(t, int64(2), hits, "Should have 2 hits")
|
||||
assert.Equal(t, int64(1), misses, "Should have 1 miss")
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
|
||||
func BenchmarkPipeline_SetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
|
||||
func BenchmarkPipeline_GetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
// Prepare keys
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("bench-key-%d", i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
|
||||
func BenchmarkPipeline_VsSequential(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
keys := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("compare-key-%d", i)
|
||||
keys[i] = key
|
||||
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
|
||||
}
|
||||
|
||||
b.Run("Pipelined-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for key, value := range items {
|
||||
_ = backend.Set(ctx, key, value, time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Pre-populate for get benchmarks
|
||||
_ = backend.SetMany(ctx, items, time.Hour)
|
||||
|
||||
b.Run("Pipelined-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, key := range keys {
|
||||
_, _, _, _ = backend.Get(ctx, key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
+117
@@ -336,3 +336,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Pipeline represents a Redis pipeline for batch operations
|
||||
// It queues multiple commands and executes them in a single round-trip
|
||||
type Pipeline struct {
|
||||
conn *RedisConn
|
||||
commands []pipelineCommand
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// pipelineCommand represents a single command in the pipeline
|
||||
type pipelineCommand struct {
|
||||
command string
|
||||
args []string
|
||||
}
|
||||
|
||||
// NewPipeline creates a new pipeline for the connection
|
||||
func (c *RedisConn) NewPipeline() *Pipeline {
|
||||
return &Pipeline{
|
||||
conn: c,
|
||||
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
|
||||
}
|
||||
}
|
||||
|
||||
// Queue adds a command to the pipeline
|
||||
func (p *Pipeline) Queue(command string, args ...string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.commands = append(p.commands, pipelineCommand{
|
||||
command: command,
|
||||
args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute sends all queued commands and returns all responses
|
||||
// Returns a slice of responses in the same order as commands were queued
|
||||
func (p *Pipeline) Execute() ([]interface{}, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.conn.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.conn.mu.Lock()
|
||||
defer p.conn.mu.Unlock()
|
||||
|
||||
// Set write timeout for all commands
|
||||
if p.conn.writeTimeout > 0 {
|
||||
// Use longer timeout for batch operations
|
||||
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second // Cap at 30 seconds
|
||||
}
|
||||
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Write all commands (pipelining - send all before reading any responses)
|
||||
writer := NewRESPWriter(p.conn.conn)
|
||||
for _, cmd := range p.commands {
|
||||
cmdArgs := append([]string{cmd.command}, cmd.args...)
|
||||
if err := writer.WriteCommand(cmdArgs...); err != nil {
|
||||
writer.Release()
|
||||
p.conn.closed.Store(true)
|
||||
return nil, fmt.Errorf("pipeline write error: %w", err)
|
||||
}
|
||||
}
|
||||
writer.Release()
|
||||
|
||||
// Set read timeout for all responses
|
||||
if p.conn.readTimeout > 0 {
|
||||
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Read all responses
|
||||
responses := make([]interface{}, len(p.commands))
|
||||
reader := NewRESPReader(p.conn.conn)
|
||||
defer reader.Release()
|
||||
|
||||
for i := range p.commands {
|
||||
resp, err := reader.ReadResponse()
|
||||
if err != nil {
|
||||
// For nil responses, store nil instead of erroring
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
responses[i] = nil
|
||||
continue
|
||||
}
|
||||
p.conn.closed.Store(true)
|
||||
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
|
||||
}
|
||||
responses[i] = resp
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
|
||||
// Clear resets the pipeline for reuse
|
||||
func (p *Pipeline) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.commands = p.commands[:0]
|
||||
}
|
||||
|
||||
// Len returns the number of queued commands
|
||||
func (p *Pipeline) Len() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.commands)
|
||||
}
|
||||
|
||||
+183
@@ -0,0 +1,183 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleflightCache wraps a CacheBackend with singleflight deduplication
|
||||
// to prevent thundering herd problems when multiple concurrent requests
|
||||
// try to fetch the same uncached key.
|
||||
type SingleflightCache struct {
|
||||
backend CacheBackend
|
||||
mu sync.Mutex
|
||||
calls map[string]*singleflightCall
|
||||
|
||||
// Metrics
|
||||
deduplicatedCalls atomic.Int64
|
||||
totalCalls atomic.Int64
|
||||
}
|
||||
|
||||
// singleflightCall represents an in-flight or completed fetch call
|
||||
type singleflightCall struct {
|
||||
wg sync.WaitGroup
|
||||
val []byte
|
||||
ttl time.Duration
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
// NewSingleflightCache creates a new singleflight-wrapped cache backend
|
||||
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
|
||||
return &SingleflightCache{
|
||||
backend: backend,
|
||||
calls: make(map[string]*singleflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
// Fetcher is a function type that fetches data when cache misses
|
||||
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
|
||||
|
||||
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
|
||||
// per key when there's a cache miss. Concurrent calls for the same key will
|
||||
// wait for the first call to complete and share its result.
|
||||
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
|
||||
s.totalCalls.Add(1)
|
||||
|
||||
// Try cache first
|
||||
value, _, exists, err := s.backend.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Cache miss - use singleflight
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if there's already an in-flight call for this key
|
||||
if call, ok := s.calls[key]; ok {
|
||||
s.mu.Unlock()
|
||||
s.deduplicatedCalls.Add(1)
|
||||
|
||||
// Wait for the in-flight call to complete
|
||||
call.wg.Wait()
|
||||
|
||||
// Check context cancellation
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Create new call
|
||||
call := &singleflightCall{}
|
||||
call.wg.Add(1)
|
||||
s.calls[key] = call
|
||||
s.mu.Unlock()
|
||||
|
||||
// Execute the fetcher
|
||||
call.val, call.ttl, call.err = fetcher(ctx)
|
||||
call.done = true
|
||||
|
||||
// If successful, store in cache
|
||||
if call.err == nil && call.val != nil {
|
||||
// Use a background context for cache storage to ensure it completes
|
||||
// even if the original context is cancelled
|
||||
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Signal waiting goroutines
|
||||
call.wg.Done()
|
||||
|
||||
// Clean up the call from the map after a short delay
|
||||
// This allows late arrivals to still benefit from the result
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
s.mu.Lock()
|
||||
if c, ok := s.calls[key]; ok && c == call {
|
||||
delete(s.calls, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the underlying cache backend
|
||||
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
return s.backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Set stores a value in the underlying cache backend
|
||||
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.backend.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a key from the underlying cache backend
|
||||
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the underlying cache backend
|
||||
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the underlying cache backend
|
||||
func (s *SingleflightCache) Clear(ctx context.Context) error {
|
||||
return s.backend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics including singleflight metrics
|
||||
func (s *SingleflightCache) GetStats() map[string]interface{} {
|
||||
stats := s.backend.GetStats()
|
||||
|
||||
// Add singleflight-specific stats
|
||||
totalCalls := s.totalCalls.Load()
|
||||
deduped := s.deduplicatedCalls.Load()
|
||||
|
||||
stats["singleflight_total_calls"] = totalCalls
|
||||
stats["singleflight_deduplicated"] = deduped
|
||||
if totalCalls > 0 {
|
||||
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
|
||||
} else {
|
||||
stats["singleflight_dedup_rate"] = float64(0)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
stats["singleflight_inflight"] = len(s.calls)
|
||||
s.mu.Unlock()
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend
|
||||
func (s *SingleflightCache) Close() error {
|
||||
return s.backend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (s *SingleflightCache) Ping(ctx context.Context) error {
|
||||
return s.backend.Ping(ctx)
|
||||
}
|
||||
|
||||
// GetBackend returns the underlying cache backend
|
||||
func (s *SingleflightCache) GetBackend() CacheBackend {
|
||||
return s.backend
|
||||
}
|
||||
|
||||
// ResetStats resets the singleflight statistics
|
||||
func (s *SingleflightCache) ResetStats() {
|
||||
s.totalCalls.Store(0)
|
||||
s.deduplicatedCalls.Store(0)
|
||||
}
|
||||
|
||||
// Ensure SingleflightCache implements CacheBackend
|
||||
var _ CacheBackend = (*SingleflightCache)(nil)
|
||||
+510
@@ -0,0 +1,510 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
|
||||
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CacheHit", func(t *testing.T) {
|
||||
key := "existing-key"
|
||||
value := []byte("existing-value")
|
||||
|
||||
// Pre-populate cache
|
||||
err := cache.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return []byte("fetched-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, result)
|
||||
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
|
||||
})
|
||||
|
||||
t.Run("CacheMiss", func(t *testing.T) {
|
||||
key := "missing-key"
|
||||
expectedValue := []byte("fetched-value")
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, result)
|
||||
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
|
||||
|
||||
// Verify value was stored in cache
|
||||
cached, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, cached)
|
||||
})
|
||||
|
||||
t.Run("FetcherError", func(t *testing.T) {
|
||||
key := "error-key"
|
||||
expectedErr := errors.New("fetch failed")
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Verify nothing was stored in cache
|
||||
_, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
|
||||
func TestSingleflightCache_Deduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "dedup-key"
|
||||
expectedValue := []byte("dedup-value")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([][]byte, concurrency)
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same result
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.NoError(t, errs[i])
|
||||
assert.Equal(t, expectedValue, results[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
|
||||
|
||||
// Verify deduplication stats
|
||||
stats := cache.GetStats()
|
||||
deduped := stats["singleflight_deduplicated"].(int64)
|
||||
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
|
||||
func TestSingleflightCache_DifferentKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetchStarted := make(chan struct{}, 3)
|
||||
fetchComplete := make(chan struct{})
|
||||
|
||||
fetcher := func(key string) Fetcher {
|
||||
return func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
fetchStarted <- struct{}{}
|
||||
<-fetchComplete // Wait for signal
|
||||
return []byte("value-" + key), time.Minute, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Launch concurrent requests for different keys
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("key-%d", idx)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all fetches to start
|
||||
for i := 0; i < 3; i++ {
|
||||
<-fetchStarted
|
||||
}
|
||||
|
||||
// All 3 fetches should be running in parallel
|
||||
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
|
||||
|
||||
// Release all fetches
|
||||
close(fetchComplete)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ContextCancellation tests context cancellation
|
||||
func TestSingleflightCache_ContextCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
key := "cancel-key"
|
||||
fetchStarted := make(chan struct{})
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
close(fetchStarted)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Start first request with long timeout
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := context.Background()
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}()
|
||||
|
||||
// Wait for fetch to start
|
||||
<-fetchStarted
|
||||
|
||||
// Start second request with short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
|
||||
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "error-prop-key"
|
||||
expectedErr := errors.New("intentional error")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 5
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same error
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.Error(t, errs[i])
|
||||
assert.Equal(t, expectedErr, errs[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load())
|
||||
}
|
||||
|
||||
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
|
||||
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Set", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, _, exists, err := cache.Get(ctx, "set-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("set-value"), val)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, ttl, exists, err := cache.Get(ctx, "get-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("get-value"), val)
|
||||
assert.Greater(t, ttl, time.Duration(0))
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := cache.Delete(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := cache.Exists(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
exists, err := cache.Exists(ctx, "nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = cache.Exists(ctx, "exists-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := cache.Exists(ctx, "clear-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
err := cache.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Stats tests statistics tracking
|
||||
func TestSingleflightCache_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Make some calls
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := cache.GetStats()
|
||||
|
||||
// Check singleflight stats exist
|
||||
assert.Contains(t, stats, "singleflight_total_calls")
|
||||
assert.Contains(t, stats, "singleflight_deduplicated")
|
||||
assert.Contains(t, stats, "singleflight_dedup_rate")
|
||||
assert.Contains(t, stats, "singleflight_inflight")
|
||||
|
||||
// Verify values
|
||||
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
|
||||
|
||||
// Also check underlying backend stats are included
|
||||
assert.Contains(t, stats, "hits")
|
||||
assert.Contains(t, stats, "misses")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ResetStats tests stats reset
|
||||
func TestSingleflightCache_ResetStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
|
||||
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
|
||||
|
||||
stats := cache.GetStats()
|
||||
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
|
||||
|
||||
// Reset stats
|
||||
cache.ResetStats()
|
||||
|
||||
stats = cache.GetStats()
|
||||
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
|
||||
}
|
||||
|
||||
// TestSingleflightCache_GetBackend tests GetBackend method
|
||||
func TestSingleflightCache_GetBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
assert.Equal(t, backend, cache.GetBackend())
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
|
||||
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key-%d", i%100)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
|
||||
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(time.Millisecond) // Simulate slow fetch
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
|
||||
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// All goroutines hit the same key
|
||||
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FileStore implements Store using file-based storage.
|
||||
// This is the default storage backend for backward compatibility with existing deployments.
|
||||
// For distributed environments, consider using RedisStore instead.
|
||||
type FileStore struct {
|
||||
basePath string
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileStore creates a new file-based credentials store.
|
||||
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
|
||||
func NewFileStore(basePath string, logger Logger) *FileStore {
|
||||
if basePath == "" {
|
||||
basePath = "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
if logger == nil {
|
||||
logger = NoOpLogger()
|
||||
}
|
||||
return &FileStore{
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// BasePath returns the base path used for storing credentials
|
||||
func (s *FileStore) BasePath() string {
|
||||
return s.basePath
|
||||
}
|
||||
|
||||
// GetFilePath returns the file path for storing credentials for a specific provider.
|
||||
// For multi-tenant scenarios, each provider gets a separate file based on URL hash.
|
||||
func (s *FileStore) GetFilePath(providerURL string) string {
|
||||
if providerURL == "" {
|
||||
return s.basePath
|
||||
}
|
||||
|
||||
// Hash provider URL for filename safety and uniqueness
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes for shorter filename
|
||||
|
||||
ext := filepath.Ext(s.basePath)
|
||||
base := strings.TrimSuffix(s.basePath, ext)
|
||||
if ext == "" {
|
||||
ext = ".json"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s%s", base, hashStr, ext)
|
||||
}
|
||||
|
||||
// Save stores the client registration response to a file
|
||||
func (s *FileStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
if creds == nil {
|
||||
return fmt.Errorf("credentials cannot be nil")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
// Ensure parent directory exists
|
||||
dir := filepath.Dir(filePath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create credentials directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(creds, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions (owner read/write only)
|
||||
if err := os.WriteFile(filePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Saved client credentials to %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves stored credentials from a file.
|
||||
// Returns nil, nil if no credentials file exists (not an error).
|
||||
func (s *FileStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via GetFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No credentials file exists - not an error
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
||||
}
|
||||
|
||||
var creds ClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Loaded client credentials from %s", filePath)
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// Delete removes the credentials file for a provider
|
||||
func (s *FileStore) Delete(ctx context.Context, providerURL string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil // File doesn't exist, nothing to delete
|
||||
}
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Deleted client credentials from %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
func (s *FileStore) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
filePath := s.GetFilePath(providerURL)
|
||||
|
||||
_, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to check credentials file: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache defines the interface for cache operations needed by RedisStore.
|
||||
// This allows the main package to provide a cache implementation without
|
||||
// creating circular dependencies.
|
||||
type Cache interface {
|
||||
// Get retrieves a value from the cache
|
||||
Get(key string) (any, bool)
|
||||
// Set stores a value in the cache with a TTL
|
||||
Set(key string, value any, ttl time.Duration) error
|
||||
// Delete removes a value from the cache
|
||||
Delete(key string)
|
||||
}
|
||||
|
||||
// RedisStore implements Store using a Cache-backed storage.
|
||||
// This storage backend enables sharing DCR credentials across multiple Traefik instances
|
||||
// in distributed environments (e.g., Kubernetes with multiple ingress pods).
|
||||
type RedisStore struct {
|
||||
cache Cache
|
||||
keyPrefix string
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new cache-backed credentials store.
|
||||
// The cache should be configured with a Redis backend for distributed storage.
|
||||
// If keyPrefix is empty, defaults to "dcr:creds:"
|
||||
func NewRedisStore(cache Cache, keyPrefix string, logger Logger) *RedisStore {
|
||||
if keyPrefix == "" {
|
||||
keyPrefix = "dcr:creds:"
|
||||
}
|
||||
if logger == nil {
|
||||
logger = NoOpLogger()
|
||||
}
|
||||
return &RedisStore{
|
||||
cache: cache,
|
||||
keyPrefix: keyPrefix,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// makeKey creates a unique cache key for a provider URL.
|
||||
// Uses SHA256 hash of the provider URL for consistent key generation across nodes.
|
||||
func (s *RedisStore) makeKey(providerURL string) string {
|
||||
if providerURL == "" {
|
||||
return s.keyPrefix + "default"
|
||||
}
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
return s.keyPrefix + hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Save stores the client registration response in the cache.
|
||||
// TTL is calculated based on client_secret_expires_at if available.
|
||||
func (s *RedisStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
|
||||
if creds == nil {
|
||||
return fmt.Errorf("credentials cannot be nil")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
|
||||
// Calculate TTL based on client_secret_expires_at if available
|
||||
ttl := 30 * 24 * time.Hour // Default: 30 days
|
||||
if creds.ClientSecretExpiresAt > 0 {
|
||||
expiresAt := time.Unix(creds.ClientSecretExpiresAt, 0)
|
||||
ttl = time.Until(expiresAt)
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("credentials already expired")
|
||||
}
|
||||
// Add a small buffer to ensure we don't serve expired credentials
|
||||
if ttl > time.Minute {
|
||||
ttl -= time.Minute
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize credentials to JSON for storage
|
||||
data, err := json.Marshal(creds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Store as string in cache (will be serialized by the cache backend)
|
||||
if err := s.cache.Set(key, string(data), ttl); err != nil {
|
||||
return fmt.Errorf("failed to store credentials in cache: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Saved client credentials to cache with key %s (TTL: %v)", key, ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load retrieves stored credentials from the cache.
|
||||
// Returns nil, nil if no credentials exist (not an error).
|
||||
func (s *RedisStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
|
||||
value, exists := s.cache.Get(key)
|
||||
if !exists {
|
||||
return nil, nil // No credentials stored - not an error
|
||||
}
|
||||
|
||||
// Handle different value types from cache
|
||||
var jsonData string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
jsonData = v
|
||||
case []byte:
|
||||
jsonData = string(v)
|
||||
default:
|
||||
// Try to see if it's already the struct (from local cache)
|
||||
if creds, ok := value.(*ClientRegistrationResponse); ok {
|
||||
return creds, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected credentials type in cache: %T", value)
|
||||
}
|
||||
|
||||
var creds ClientRegistrationResponse
|
||||
if err := json.Unmarshal([]byte(jsonData), &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials from cache: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debugf("Loaded client credentials from cache with key %s", key)
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// Delete removes stored credentials from the cache
|
||||
func (s *RedisStore) Delete(ctx context.Context, providerURL string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
s.cache.Delete(key)
|
||||
|
||||
s.logger.Debugf("Deleted client credentials from cache with key %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if credentials exist in the cache for a provider
|
||||
func (s *RedisStore) Exists(ctx context.Context, providerURL string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
key := s.makeKey(providerURL)
|
||||
_, exists := s.cache.Get(key)
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
// Package dcrstorage provides storage backends for OIDC Dynamic Client Registration credentials.
|
||||
// It supports both file-based and Redis-based storage for persisting client credentials
|
||||
// across application restarts and distributed deployments.
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// StorageBackend represents the type of storage backend for DCR credentials
|
||||
type StorageBackend string
|
||||
|
||||
const (
|
||||
// StorageBackendFile uses file-based storage (default for backward compatibility)
|
||||
StorageBackendFile StorageBackend = "file"
|
||||
|
||||
// StorageBackendRedis uses Redis for distributed storage
|
||||
StorageBackendRedis StorageBackend = "redis"
|
||||
|
||||
// StorageBackendAuto automatically selects Redis if available, otherwise file
|
||||
StorageBackendAuto StorageBackend = "auto"
|
||||
)
|
||||
|
||||
// Logger interface for DCR storage operations
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...any)
|
||||
Info(msg string)
|
||||
Infof(format string, args ...any)
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
}
|
||||
|
||||
// Store defines the interface for storing DCR credentials.
|
||||
// This abstraction allows different storage backends (file, Redis) to be used
|
||||
// for persisting OIDC Dynamic Client Registration credentials across nodes.
|
||||
type Store interface {
|
||||
// Save stores the client registration response for a provider
|
||||
// The providerURL is used as a key to support multi-tenant scenarios
|
||||
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
|
||||
|
||||
// Load retrieves stored credentials for a provider
|
||||
// Returns nil, nil if no credentials exist (not an error)
|
||||
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
|
||||
|
||||
// Delete removes stored credentials for a provider
|
||||
Delete(ctx context.Context, providerURL string) error
|
||||
|
||||
// Exists checks if credentials exist for a provider
|
||||
Exists(ctx context.Context, providerURL string) (bool, error)
|
||||
}
|
||||
|
||||
// noOpLogger is a no-op implementation of Logger for default use
|
||||
type noOpLogger struct{}
|
||||
|
||||
func (n noOpLogger) Debug(msg string) {}
|
||||
func (n noOpLogger) Debugf(format string, args ...any) {}
|
||||
func (n noOpLogger) Info(msg string) {}
|
||||
func (n noOpLogger) Infof(format string, args ...any) {}
|
||||
func (n noOpLogger) Error(msg string) {}
|
||||
func (n noOpLogger) Errorf(format string, args ...any) {}
|
||||
|
||||
// NoOpLogger returns a no-op logger instance
|
||||
func NoOpLogger() Logger {
|
||||
return noOpLogger{}
|
||||
}
|
||||
@@ -0,0 +1,464 @@
|
||||
package dcrstorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockCache implements Cache for testing
|
||||
type mockCache struct {
|
||||
data map[string]cacheEntry
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value any
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newMockCache() *mockCache {
|
||||
return &mockCache{data: make(map[string]cacheEntry)}
|
||||
}
|
||||
|
||||
func (m *mockCache) Get(key string) (any, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
entry, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return nil, false
|
||||
}
|
||||
return entry.value, true
|
||||
}
|
||||
|
||||
func (m *mockCache) Set(key string, value any, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = cacheEntry{
|
||||
value: value,
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCache) Delete(key string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
}
|
||||
|
||||
func TestFileStore_SaveLoad(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "test-access-token",
|
||||
RegistrationClientURI: "https://example.com/register/test-client-id",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
ResponseTypes: []string{"code"},
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
|
||||
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
tempDir2 := t.TempDir()
|
||||
store2 := NewFileStore(filepath.Join(tempDir2, "nonexistent.json"), nil)
|
||||
|
||||
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent file: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
|
||||
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if exists {
|
||||
t.Error("Expected credentials to not exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existent credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete should not error for non-existent: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_MultiProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
provider1 := "https://auth1.example.com"
|
||||
provider2 := "https://auth2.example.com"
|
||||
|
||||
creds1 := &ClientRegistrationResponse{
|
||||
ClientID: "client-1",
|
||||
ClientSecret: "secret-1",
|
||||
}
|
||||
creds2 := &ClientRegistrationResponse{
|
||||
ClientID: "client-2",
|
||||
ClientSecret: "secret-2",
|
||||
}
|
||||
|
||||
if err := store.Save(ctx, provider1, creds1); err != nil {
|
||||
t.Fatalf("Failed to save creds1: %v", err)
|
||||
}
|
||||
if err := store.Save(ctx, provider2, creds2); err != nil {
|
||||
t.Fatalf("Failed to save creds2: %v", err)
|
||||
}
|
||||
|
||||
loaded1, err := store.Load(ctx, provider1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds1: %v", err)
|
||||
}
|
||||
if loaded1.ClientID != "client-1" {
|
||||
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
|
||||
}
|
||||
|
||||
loaded2, err := store.Load(ctx, provider2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load creds2: %v", err)
|
||||
}
|
||||
if loaded2.ClientID != "client-2" {
|
||||
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
|
||||
}
|
||||
|
||||
if err := store.Delete(ctx, provider1); err != nil {
|
||||
t.Fatalf("Failed to delete creds1: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, provider2)
|
||||
if !exists {
|
||||
t.Error("Provider 2 credentials should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
|
||||
for range concurrency {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = store.Save(ctx, providerURL, creds)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for range concurrency {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = store.Load(ctx, providerURL)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after concurrent access: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test-client" {
|
||||
t.Error("Credentials corrupted after concurrent access")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty provider URL uses default path", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
err := store.Save(ctx, "", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Save with empty provider URL failed: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Load with empty provider URL failed: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials with empty provider URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_DefaultPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := NewFileStore("", nil)
|
||||
|
||||
if store.BasePath() == "" {
|
||||
t.Error("Expected default base path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisStore_WithMockCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
testCreds := &ClientRegistrationResponse{
|
||||
ClientID: "redis-test-client",
|
||||
ClientSecret: "redis-test-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
|
||||
RegistrationAccessToken: "redis-test-token",
|
||||
RedirectURIs: []string{"https://app.example.com/callback"},
|
||||
}
|
||||
|
||||
t.Run("save and load credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, providerURL, testCreds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load credentials: %v", err)
|
||||
}
|
||||
|
||||
if loaded == nil {
|
||||
t.Fatal("Expected credentials but got nil")
|
||||
}
|
||||
if loaded.ClientID != testCreds.ClientID {
|
||||
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
|
||||
}
|
||||
if loaded.ClientSecret != testCreds.ClientSecret {
|
||||
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exists check", func(t *testing.T) {
|
||||
exists, err := store.Exists(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Exists check failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Error("Expected credentials to exist")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete credentials", func(t *testing.T) {
|
||||
err := store.Delete(ctx, providerURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete credentials: %v", err)
|
||||
}
|
||||
|
||||
exists, _ := store.Exists(ctx, providerURL)
|
||||
if exists {
|
||||
t.Error("Expected credentials to be deleted")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load non-existent credentials", func(t *testing.T) {
|
||||
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error for non-existent: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("Expected nil for non-existent credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisStore_TTLFromExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("expired credentials should fail", func(t *testing.T) {
|
||||
expiredCreds := &ClientRegistrationResponse{
|
||||
ClientID: "expired-client",
|
||||
ClientSecret: "expired-secret",
|
||||
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired credentials")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
|
||||
creds := &ClientRegistrationResponse{
|
||||
ClientID: "no-expiry-client",
|
||||
ClientSecret: "no-expiry-secret",
|
||||
ClientSecretExpiresAt: 0,
|
||||
}
|
||||
|
||||
err := store.Save(ctx, "https://noexpiry.example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save credentials without expiry: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisStore_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := newMockCache()
|
||||
store := NewRedisStore(cache, "", nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("save nil credentials", func(t *testing.T) {
|
||||
err := store.Save(ctx, "https://example.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFileStore_CorruptedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
basePath := filepath.Join(tempDir, "credentials.json")
|
||||
store := NewFileStore(basePath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
filePath := store.GetFilePath(providerURL)
|
||||
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
|
||||
t.Fatalf("Failed to write corrupted file: %v", err)
|
||||
}
|
||||
|
||||
_, err := store.Load(ctx, providerURL)
|
||||
if err == nil {
|
||||
t.Error("Expected error for corrupted JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStore_DirectoryCreation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempDir := t.TempDir()
|
||||
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
|
||||
store := NewFileStore(deepPath, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
creds := &ClientRegistrationResponse{ClientID: "test"}
|
||||
|
||||
err := store.Save(ctx, "https://example.com", creds)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save with nested directory: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := store.Load(ctx, "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load after nested directory creation: %v", err)
|
||||
}
|
||||
if loaded == nil || loaded.ClientID != "test" {
|
||||
t.Error("Failed to load credentials from nested directory")
|
||||
}
|
||||
}
|
||||
@@ -433,6 +433,19 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
|
||||
t.dcrConfig,
|
||||
t.providerURL,
|
||||
)
|
||||
|
||||
// Set up storage backend for credentials persistence
|
||||
if t.dcrConfig.PersistCredentials {
|
||||
cacheManager := GetGlobalCacheManagerWithConfig(t.goroutineWG, nil)
|
||||
store, err := NewDCRCredentialsStore(t.dcrConfig, cacheManager, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to create DCR credentials store: %v", err)
|
||||
// Continue without persistence - registration will still work
|
||||
} else {
|
||||
t.dynamicClientRegistrar.SetStore(store)
|
||||
t.logger.Debugf("DCR credentials store initialized with backend: %s", t.dcrConfig.StorageBackend)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get registration endpoint (from metadata or config override)
|
||||
|
||||
+9
-2
@@ -98,8 +98,15 @@ type DynamicClientRegistrationConfig struct {
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
// StorageBackend specifies where to store DCR credentials: "file", "redis", or "auto"
|
||||
// - "file": Use file-based storage (default for backward compatibility)
|
||||
// - "redis": Use Redis exclusively (fails if Redis unavailable)
|
||||
// - "auto": Use Redis if available, fallback to file (default)
|
||||
StorageBackend string `json:"storageBackend,omitempty"`
|
||||
// RedisKeyPrefix is the prefix for Redis keys when using Redis storage (default: "dcr:creds:")
|
||||
RedisKeyPrefix string `json:"redisKeyPrefix,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
|
||||
+1
-1
@@ -436,7 +436,7 @@ func (c *UniversalCache) Clear() {
|
||||
c.currentSize = 0
|
||||
c.currentMemory = 0
|
||||
|
||||
c.logger.Infof("UniversalCache[%s]: Cleared all items", c.config.Type)
|
||||
c.logger.Debugf("UniversalCache[%s]: Cleared all items", c.config.Type)
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
|
||||
@@ -13,20 +13,21 @@ import (
|
||||
// It runs a single consolidated cleanup goroutine for all caches, reducing
|
||||
// goroutine count and CPU overhead compared to per-cache cleanup routines.
|
||||
type UniversalCacheManager struct {
|
||||
sharedBackend backends.CacheBackend
|
||||
ctx context.Context
|
||||
tokenTypeCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache
|
||||
tokenCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
cleanupStarted bool
|
||||
sharedBackend backends.CacheBackend
|
||||
ctx context.Context
|
||||
tokenTypeCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache
|
||||
tokenCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
cleanupStarted bool
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -349,6 +350,19 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// DCR credentials cache - CRITICAL for distributed DCR across multiple nodes
|
||||
// Uses Redis backend to share client credentials across all Traefik replicas
|
||||
manager.dcrCredentialsCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 100, // Few providers expected
|
||||
DefaultTTL: 30 * 24 * time.Hour, // 30 days default (credentials are long-lived)
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("dcr"),
|
||||
)
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
@@ -396,6 +410,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
|
||||
m.sessionCache,
|
||||
m.introspectionCache,
|
||||
m.tokenTypeCache,
|
||||
m.dcrCredentialsCache,
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -458,6 +473,13 @@ func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
|
||||
return m.tokenTypeCache
|
||||
}
|
||||
|
||||
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
|
||||
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.dcrCredentialsCache
|
||||
}
|
||||
|
||||
// Close shuts down all caches and the consolidated cleanup routine
|
||||
func (m *UniversalCacheManager) Close() error {
|
||||
// Stop the consolidated cleanup routine first
|
||||
@@ -473,7 +495,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
|
||||
// Close all caches first (they won't close the shared backend)
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
|
||||
Reference in New Issue
Block a user