... and another all nighter.

This commit is contained in:
2025-10-22 10:46:12 +01:00
parent c16b94a9c2
commit c05932cf8a
27 changed files with 7318 additions and 600 deletions
+258
View File
@@ -0,0 +1,258 @@
// Package config provides backward compatibility for legacy configuration
package config
import (
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/compat"
"github.com/lukaszraczylo/traefikoidc/internal/features"
)
// LegacyAdapter provides backward compatibility for old Config struct
type LegacyAdapter struct {
unified *UnifiedConfig
adapter *compat.ConfigAdapter
}
// NewLegacyAdapter creates a new legacy adapter from unified config
func NewLegacyAdapter(unified *UnifiedConfig) *LegacyAdapter {
adapter := compat.NewConfigAdapter(unified)
// Register getters for commonly used fields
adapter.RegisterGetter("ProviderURL", func() interface{} {
return unified.Provider.IssuerURL
})
adapter.RegisterGetter("ClientID", func() interface{} {
return unified.Provider.ClientID
})
adapter.RegisterGetter("ClientSecret", func() interface{} {
return unified.Provider.ClientSecret
})
adapter.RegisterGetter("CallbackURL", func() interface{} {
return unified.Provider.RedirectURL
})
adapter.RegisterGetter("LogoutURL", func() interface{} {
return unified.Provider.LogoutURL
})
adapter.RegisterGetter("PostLogoutRedirectURI", func() interface{} {
return unified.Provider.PostLogoutRedirectURI
})
adapter.RegisterGetter("SessionEncryptionKey", func() interface{} {
return unified.Session.EncryptionKey
})
adapter.RegisterGetter("ForceHTTPS", func() interface{} {
return unified.Security.ForceHTTPS
})
adapter.RegisterGetter("LogLevel", func() interface{} {
return unified.Logging.Level
})
adapter.RegisterGetter("Scopes", func() interface{} {
return unified.Provider.Scopes
})
adapter.RegisterGetter("OverrideScopes", func() interface{} {
return unified.Provider.OverrideScopes
})
adapter.RegisterGetter("AllowedUsers", func() interface{} {
return unified.Security.AllowedUsers
})
adapter.RegisterGetter("AllowedUserDomains", func() interface{} {
return unified.Security.AllowedUserDomains
})
adapter.RegisterGetter("AllowedRolesAndGroups", func() interface{} {
return unified.Security.AllowedRolesAndGroups
})
adapter.RegisterGetter("ExcludedURLs", func() interface{} {
return unified.Security.ExcludedURLs
})
adapter.RegisterGetter("EnablePKCE", func() interface{} {
return unified.Security.EnablePKCE
})
adapter.RegisterGetter("RateLimit", func() interface{} {
return unified.RateLimit.RequestsPerSecond
})
adapter.RegisterGetter("RefreshGracePeriodSeconds", func() interface{} {
return int(unified.Token.RefreshGracePeriod.Seconds())
})
adapter.RegisterGetter("CookieDomain", func() interface{} {
return unified.Session.Domain
})
adapter.RegisterGetter("SecurityHeaders", func() interface{} {
return unified.Security.Headers
})
return &LegacyAdapter{
unified: unified,
adapter: adapter,
}
}
// ToOldConfig converts unified config to old Config struct format
func (la *LegacyAdapter) ToOldConfig() *Config {
// Use feature flags to determine behavior
if !features.IsUnifiedConfigEnabled() {
// Return existing Config if unified config not enabled
return CreateConfig()
}
cfg := &Config{
ProviderURL: la.unified.Provider.IssuerURL,
ClientID: la.unified.Provider.ClientID,
ClientSecret: la.unified.Provider.ClientSecret,
CallbackURL: la.unified.Provider.RedirectURL,
LogoutURL: la.unified.Provider.LogoutURL,
PostLogoutRedirectURI: la.unified.Provider.PostLogoutRedirectURI,
SessionEncryptionKey: la.unified.Session.EncryptionKey,
ForceHTTPS: la.unified.Security.ForceHTTPS,
LogLevel: la.unified.Logging.Level,
Scopes: la.unified.Provider.Scopes,
OverrideScopes: la.unified.Provider.OverrideScopes,
AllowedUsers: la.unified.Security.AllowedUsers,
AllowedUserDomains: la.unified.Security.AllowedUserDomains,
AllowedRolesAndGroups: la.unified.Security.AllowedRolesAndGroups,
ExcludedURLs: la.unified.Security.ExcludedURLs,
EnablePKCE: la.unified.Security.EnablePKCE,
RateLimit: la.unified.RateLimit.RequestsPerSecond,
RefreshGracePeriodSeconds: int(la.unified.Token.RefreshGracePeriod.Seconds()),
Headers: la.convertHeaders(),
CookieDomain: la.unified.Session.Domain,
SecurityHeaders: la.unified.Security.Headers,
}
return cfg
}
// convertHeaders converts unified header config to old format
func (la *LegacyAdapter) convertHeaders() []HeaderConfig {
headers := make([]HeaderConfig, 0)
for name, value := range la.unified.Middleware.CustomHeaders {
headers = append(headers, HeaderConfig{
Name: name,
Value: value,
})
}
return headers
}
// FromOldConfig creates unified config from old Config struct
func FromOldConfig(old *Config) *UnifiedConfig {
unified := NewUnifiedConfig()
// Map provider settings
unified.Provider.IssuerURL = old.ProviderURL
unified.Provider.ClientID = old.ClientID
unified.Provider.ClientSecret = old.ClientSecret
unified.Provider.RedirectURL = old.CallbackURL
unified.Provider.LogoutURL = old.LogoutURL
unified.Provider.PostLogoutRedirectURI = old.PostLogoutRedirectURI
unified.Provider.Scopes = old.Scopes
unified.Provider.OverrideScopes = old.OverrideScopes
// Map session settings
unified.Session.EncryptionKey = old.SessionEncryptionKey
unified.Session.Domain = old.CookieDomain
// Map security settings
unified.Security.ForceHTTPS = old.ForceHTTPS
unified.Security.EnablePKCE = old.EnablePKCE
unified.Security.AllowedUsers = old.AllowedUsers
unified.Security.AllowedUserDomains = old.AllowedUserDomains
unified.Security.AllowedRolesAndGroups = old.AllowedRolesAndGroups
unified.Security.ExcludedURLs = old.ExcludedURLs
unified.Security.Headers = old.SecurityHeaders
// Map rate limiting
unified.RateLimit.RequestsPerSecond = old.RateLimit
unified.RateLimit.Enabled = old.RateLimit > 0
// Map token settings
unified.Token.RefreshGracePeriod = timeSecondsToDuration(old.RefreshGracePeriodSeconds)
// Map logging
unified.Logging.Level = old.LogLevel
// Map custom headers
if len(old.Headers) > 0 {
unified.Middleware.CustomHeaders = make(map[string]string)
for _, header := range old.Headers {
unified.Middleware.CustomHeaders[header.Name] = header.Value
}
}
// Store original config in legacy field for reference
unified.Legacy["original"] = old
return unified
}
// timeSecondsToDuration converts seconds to time.Duration
func timeSecondsToDuration(seconds int) time.Duration {
return time.Duration(seconds) * time.Second
}
// GetConfigInterface returns appropriate config based on feature flag
func GetConfigInterface() interface{} {
if features.IsUnifiedConfigEnabled() {
return NewUnifiedConfig()
}
return CreateConfig()
}
// ValidateConfig validates config based on feature flag
func ValidateConfig(cfg interface{}) error {
if features.IsUnifiedConfigEnabled() {
if unified, ok := cfg.(*UnifiedConfig); ok {
return unified.Validate()
}
}
// Fall back to old validation if available
if old, ok := cfg.(*Config); ok {
return old.Validate()
}
return nil
}
// Add Validate method to old Config for compatibility
func (c *Config) Validate() error {
var errors ValidationErrors
// Basic validation for old config
if c.ProviderURL == "" {
errors = append(errors, ValidationError{
Field: "ProviderURL",
Message: "provider URL is required",
})
}
if c.ClientID == "" {
errors = append(errors, ValidationError{
Field: "ClientID",
Message: "client ID is required",
})
}
if c.ClientSecret == "" && !c.EnablePKCE {
errors = append(errors, ValidationError{
Field: "ClientSecret",
Message: "client secret is required (or enable PKCE)",
})
}
if c.SessionEncryptionKey != "" && len(c.SessionEncryptionKey) < minEncryptionKeyLength {
errors = append(errors, ValidationError{
Field: "SessionEncryptionKey",
Message: fmt.Sprintf("encryption key must be at least %d characters", minEncryptionKeyLength),
Value: len(c.SessionEncryptionKey),
})
}
if len(errors) > 0 {
return errors
}
return nil
}
+276
View File
@@ -0,0 +1,276 @@
// Package config provides default values and initialization for unified configuration
package config
import (
"time"
)
// NewUnifiedConfig creates a new unified configuration with sensible defaults
func NewUnifiedConfig() *UnifiedConfig {
return &UnifiedConfig{
Provider: DefaultProviderConfig(),
Session: DefaultSessionConfig(),
Token: DefaultTokenConfig(),
Redis: *DefaultRedisConfig(), // Using existing DefaultRedisConfig
Security: DefaultSecurityConfig(),
Middleware: DefaultMiddlewareConfig(),
Cache: DefaultCacheConfig(),
RateLimit: DefaultRateLimitConfig(),
Logging: DefaultLoggingConfig(),
Metrics: DefaultMetricsConfig(),
Health: DefaultHealthConfig(),
Transport: DefaultTransportConfig(),
Pool: DefaultPoolConfig(),
Circuit: DefaultCircuitConfig(),
Legacy: make(map[string]interface{}),
}
}
// DefaultProviderConfig returns default provider configuration
func DefaultProviderConfig() ProviderConfig {
return ProviderConfig{
Scopes: []string{"openid", "profile", "email"},
OverrideScopes: false,
CustomClaims: make(map[string]string),
JWKCachePeriod: 24 * time.Hour,
MetadataCacheTTL: 24 * time.Hour,
Discovery: true,
}
}
// DefaultSessionConfig returns default session configuration
func DefaultSessionConfig() SessionConfig {
return SessionConfig{
Name: "oidc_session",
MaxAge: 86400, // 24 hours
ChunkSize: 4000, // Safe size for cookies
MaxChunks: 5,
Path: "/",
Secure: true,
HttpOnly: true,
SameSite: "Lax",
StorageType: "cookie",
CleanupInterval: 1 * time.Hour,
}
}
// DefaultTokenConfig returns default token configuration
func DefaultTokenConfig() TokenConfig {
return TokenConfig{
AccessTokenTTL: 1 * time.Hour,
RefreshTokenTTL: 24 * time.Hour,
RefreshGracePeriod: 60 * time.Second,
ValidationMode: "jwt",
CacheEnabled: true,
CacheTTL: 5 * time.Minute,
CacheNegativeTTL: 30 * time.Second,
ValidateSignature: true,
ValidateExpiry: true,
ValidateAudience: true,
ValidateIssuer: true,
RequiredClaims: []string{"sub", "iat", "exp"},
ClockSkew: 5 * time.Minute,
}
}
// DefaultSecurityConfig returns default security configuration
func DefaultSecurityConfig() SecurityConfig {
return SecurityConfig{
ForceHTTPS: true,
EnablePKCE: true,
AllowedUsers: []string{},
AllowedUserDomains: []string{},
AllowedRolesAndGroups: []string{},
ExcludedURLs: []string{
"/favicon.ico",
"/robots.txt",
"/health",
"/.well-known/",
"/metrics",
"/ping",
"/static/",
"/assets/",
"/js/",
"/css/",
"/images/",
"/fonts/",
},
Headers: createDefaultSecurityConfig(),
CSRFProtection: true,
CSRFTokenName: "csrf_token",
CSRFTokenTTL: 1 * time.Hour,
MaxLoginAttempts: 5,
LockoutDuration: 15 * time.Minute,
RequireMFA: false,
}
}
// DefaultMiddlewareConfig returns default middleware configuration
func DefaultMiddlewareConfig() MiddlewareConfig {
return MiddlewareConfig{
Priority: 1000,
SkipPaths: []string{},
RequirePaths: []string{},
PassthroughMode: false,
MaxRequestSize: 10 * 1024 * 1024, // 10MB
RequestTimeout: 30 * time.Second,
IdleTimeout: 90 * time.Second,
CustomHeaders: make(map[string]string),
RemoveHeaders: []string{},
}
}
// DefaultCacheConfig returns default cache configuration
func DefaultCacheConfig() CacheConfig {
return CacheConfig{
Enabled: true,
Type: "memory",
DefaultTTL: 5 * time.Minute,
MaxEntries: 10000,
MaxEntrySize: 1024 * 1024, // 1MB
EvictionPolicy: "lru",
CleanupInterval: 10 * time.Minute,
Namespace: "traefikoidc",
Compression: false,
Serialization: "json",
}
}
// DefaultRateLimitConfig returns default rate limiting configuration
func DefaultRateLimitConfig() RateLimitConfig {
return RateLimitConfig{
Enabled: false,
RequestsPerSecond: 10,
Burst: 20,
StorageType: "memory",
WindowDuration: 1 * time.Minute,
KeyType: "ip",
CustomKeyFunc: "",
WhitelistIPs: []string{},
WhitelistUsers: []string{},
}
}
// DefaultLoggingConfig returns default logging configuration
func DefaultLoggingConfig() LoggingConfig {
return LoggingConfig{
Level: "info",
Format: "json",
Output: "stdout",
FilePath: "",
FilterSensitive: true,
MaskFields: []string{
"password",
"secret",
"token",
"key",
"authorization",
"cookie",
},
BufferSize: 8192,
FlushInterval: 5 * time.Second,
AuditEnabled: false,
AuditEvents: []string{
"login",
"logout",
"token_refresh",
"auth_failure",
},
}
}
// DefaultMetricsConfig returns default metrics configuration
func DefaultMetricsConfig() MetricsConfig {
return MetricsConfig{
Enabled: false,
Provider: "prometheus",
Endpoint: "/metrics",
Namespace: "traefikoidc",
Subsystem: "middleware",
CollectInterval: 10 * time.Second,
Histograms: true,
Labels: make(map[string]string),
}
}
// DefaultHealthConfig returns default health check configuration
func DefaultHealthConfig() HealthConfig {
return HealthConfig{
Enabled: true,
Path: "/health",
CheckInterval: 30 * time.Second,
Timeout: 5 * time.Second,
CheckProvider: true,
CheckRedis: true,
CheckCache: true,
MaxLatency: 1 * time.Second,
MinMemory: 100 * 1024 * 1024, // 100MB
}
}
// DefaultTransportConfig returns default HTTP transport configuration
func DefaultTransportConfig() TransportConfig {
return TransportConfig{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
MaxConnsPerHost: 0, // No limit
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
DisableKeepAlives: false,
DisableCompression: false,
TLSInsecureSkipVerify: false,
TLSMinVersion: "TLS1.2",
TLSCipherSuites: []string{},
ProxyURL: "",
NoProxy: []string{},
}
}
// DefaultPoolConfig returns default connection pool configuration
func DefaultPoolConfig() PoolConfig {
return PoolConfig{
Enabled: true,
Size: 10,
MinSize: 2,
MaxSize: 50,
MaxAge: 30 * time.Minute,
IdleTimeout: 5 * time.Minute,
WaitTimeout: 5 * time.Second,
HealthCheckInterval: 30 * time.Second,
MaxRetries: 3,
}
}
// DefaultCircuitConfig returns default circuit breaker configuration
func DefaultCircuitConfig() CircuitConfig {
return CircuitConfig{
Enabled: true,
MaxRequests: 100,
Interval: 10 * time.Second,
Timeout: 60 * time.Second,
ConsecutiveFailures: 5,
FailureRatio: 0.5,
OnOpen: "reject",
OnHalfOpen: "passthrough",
MetricsEnabled: true,
LogStateChanges: true,
}
}
// MergeWithDefaults merges a partial configuration with defaults
func MergeWithDefaults(partial *UnifiedConfig) *UnifiedConfig {
if partial == nil {
return NewUnifiedConfig()
}
// Ensure Legacy field is initialized
if partial.Legacy == nil {
partial.Legacy = make(map[string]interface{})
}
// TODO: Implement deep merge logic with defaults
// For now, just return the partial config
return partial
}
+396
View File
@@ -0,0 +1,396 @@
// Package config provides configuration loading and merging logic
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"github.com/lukaszraczylo/traefikoidc/internal/features"
"gopkg.in/yaml.v3"
)
// ConfigLoader handles loading configuration from various sources
type ConfigLoader struct {
migrator *ConfigMigrator
envPrefix string
configPaths []string
}
// NewConfigLoader creates a new configuration loader
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
migrator: NewConfigMigrator(),
envPrefix: "TRAEFIKOIDC_",
configPaths: getDefaultConfigPaths(),
}
}
// getDefaultConfigPaths returns default configuration file paths to check
func getDefaultConfigPaths() []string {
return []string{
"traefik-oidc.yaml",
"traefik-oidc.yml",
"traefik-oidc.json",
"config.yaml",
"config.yml",
"config.json",
"/etc/traefik-oidc/config.yaml",
"/etc/traefik-oidc/config.json",
}
}
// Load loads configuration from all available sources
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
// Start with defaults
config := NewUnifiedConfig()
// Try to load from file
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
config = l.mergeConfigs(config, fileConfig)
}
// Load from environment variables
l.LoadFromEnv(config)
// Validate the final configuration
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
return config, nil
}
// LoadFromFile loads configuration from a file
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
// Use provided paths or default paths
searchPaths := paths
if len(searchPaths) == 0 {
searchPaths = l.configPaths
}
// Check for config file in environment variable
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
searchPaths = append([]string{envPath}, searchPaths...)
}
// Try each path
for _, path := range searchPaths {
if _, err := os.Stat(path); err == nil {
return l.loadFile(path)
}
}
// No config file found, not an error (use defaults)
return nil, nil
}
// loadFile loads a specific configuration file
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
}
// Ensure the path is within expected directories (current dir or subdirs)
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
}
// Read the file with validated path
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
}
// Check if unified config is enabled
if features.IsUnifiedConfigEnabled() {
// Use migrator to handle any version
config, warnings, err := l.migrator.Migrate(data)
if err != nil {
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
}
// Log warnings
for _, warning := range warnings {
// In production, use proper logging
fmt.Printf("Config Warning (%s): %s\n", path, warning)
}
return config, nil
}
// Legacy path: load old config and convert
ext := strings.ToLower(filepath.Ext(path))
var oldConfig Config
switch ext {
case ".json":
if err := json.Unmarshal(data, &oldConfig); err != nil {
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
}
case ".yaml", ".yml":
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
}
default:
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
}
return FromOldConfig(&oldConfig), nil
}
// LoadFromEnv loads configuration from environment variables
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
// Provider configuration
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
// Session configuration
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
// Security configuration
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
// Cache configuration
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
// MaxEntrySize is int64, skip for now
// Rate limiting
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
// Logging
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
// Redis configuration (already handled by its own LoadFromEnv)
config.Redis.LoadFromEnv()
// Feature flags
features.GetManager().LoadFromEnv()
}
// Helper methods for environment variable loading
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = value
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = value
return
}
}
}
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = strings.ToLower(value) == "true" || value == "1"
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = strings.ToLower(value) == "true" || value == "1"
return
}
}
}
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
var i int
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
*target = i
return
}
}
// Try without prefix
if value := os.Getenv(key); value != "" {
var i int
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
*target = i
return
}
}
}
}
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = splitAndTrim(value)
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = splitAndTrim(value)
return
}
}
}
func splitAndTrim(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
if trimmed := strings.TrimSpace(part); trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
// mergeConfigs merges two configurations, with source overriding target
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
if source == nil {
return target
}
if target == nil {
return source
}
// Use reflection for deep merge
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
return target
}
// mergeStructs recursively merges two structs
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
for i := 0; i < source.NumField(); i++ {
sourceField := source.Field(i)
targetField := target.Field(i)
// Skip if source field is zero value
if isZeroValue(sourceField) {
continue
}
switch sourceField.Kind() {
case reflect.Struct:
// Recursively merge structs
l.mergeStructs(targetField, sourceField)
case reflect.Slice:
// Replace slice if source has values
if sourceField.Len() > 0 {
targetField.Set(sourceField)
}
case reflect.Map:
// Merge maps
if !sourceField.IsNil() {
if targetField.IsNil() {
targetField.Set(reflect.MakeMap(sourceField.Type()))
}
for _, key := range sourceField.MapKeys() {
targetField.SetMapIndex(key, sourceField.MapIndex(key))
}
}
default:
// Replace value
targetField.Set(sourceField)
}
}
}
// isZeroValue checks if a reflect.Value is a zero value
func isZeroValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return v.IsNil()
case reflect.Slice, reflect.Map:
return v.IsNil() || v.Len() == 0
case reflect.Struct:
// Check if all fields are zero
for i := 0; i < v.NumField(); i++ {
if !isZeroValue(v.Field(i)) {
return false
}
}
return true
default:
zero := reflect.Zero(v.Type())
return reflect.DeepEqual(v.Interface(), zero.Interface())
}
}
// SaveToFile saves the configuration to a file
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
}
// Ensure the path is within expected directories
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
}
ext := strings.ToLower(filepath.Ext(absPath))
var data []byte
switch ext {
case ".json":
data, err = json.MarshalIndent(config, "", " ")
case ".yaml", ".yml":
data, err = yaml.Marshal(config)
default:
return fmt.Errorf("unsupported file extension: %s", ext)
}
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create directory if it doesn't exist with secure permissions
dir := filepath.Dir(absPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write file with secure permissions (owner read/write only)
if err := os.WriteFile(absPath, data, 0600); err != nil {
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
}
return nil
}
+292
View File
@@ -0,0 +1,292 @@
//go:build !yaegi
package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
// TestConfigLoader tests the config loader functionality
func TestConfigLoader(t *testing.T) {
loader := NewConfigLoader()
if loader == nil {
t.Fatal("NewConfigLoader should not return nil")
}
if loader.migrator == nil {
t.Error("ConfigLoader should have a migrator")
}
if loader.envPrefix != "TRAEFIKOIDC_" {
t.Errorf("Expected envPrefix to be 'TRAEFIKOIDC_', got %s", loader.envPrefix)
}
if len(loader.configPaths) == 0 {
t.Error("ConfigLoader should have default config paths")
}
}
// TestLoadFromEnv tests loading configuration from environment variables
func TestLoadFromEnv(t *testing.T) {
// Set up test environment variables
testEnvVars := map[string]string{
"TRAEFIKOIDC_PROVIDER_ISSUER_URL": "https://test.example.com",
"TRAEFIKOIDC_PROVIDER_CLIENT_ID": "test-client-id",
"TRAEFIKOIDC_PROVIDER_CLIENT_SECRET": "test-secret",
"TRAEFIKOIDC_SESSION_ENCRYPTION_KEY": "32-character-encryption-key-12345",
"TRAEFIKOIDC_SESSION_CHUNKED": "true",
"TRAEFIKOIDC_REDIS_ENABLED": "true",
"TRAEFIKOIDC_REDIS_ADDR": "redis.example.com:6379",
"TRAEFIKOIDC_SECURITY_FORCE_HTTPS": "true",
"TRAEFIKOIDC_CACHE_ENABLED": "true",
"TRAEFIKOIDC_CACHE_TYPE": "redis",
"TRAEFIKOIDC_RATELIMIT_ENABLED": "true",
"TRAEFIKOIDC_RATELIMIT_RPS": "100",
}
// Set environment variables
for key, value := range testEnvVars {
os.Setenv(key, value)
defer os.Unsetenv(key)
}
loader := NewConfigLoader()
config := &UnifiedConfig{}
loader.LoadFromEnv(config)
// Verify values were loaded
if config.Provider.IssuerURL != "https://test.example.com" {
t.Errorf("Expected IssuerURL to be 'https://test.example.com', got %s", config.Provider.IssuerURL)
}
if config.Provider.ClientID != "test-client-id" {
t.Errorf("Expected ClientID to be 'test-client-id', got %s", config.Provider.ClientID)
}
if config.Provider.ClientSecret != "test-secret" {
t.Errorf("Expected ClientSecret to be 'test-secret', got %s", config.Provider.ClientSecret)
}
if config.Session.EncryptionKey != "32-character-encryption-key-12345" {
t.Errorf("Expected EncryptionKey to be set, got %s", config.Session.EncryptionKey)
}
if !config.Security.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true")
}
if !config.Cache.Enabled {
t.Error("Expected Cache to be enabled")
}
if config.Cache.Type != "redis" {
t.Errorf("Expected Cache.Type to be 'redis', got %s", config.Cache.Type)
}
if !config.RateLimit.Enabled {
t.Error("Expected RateLimit to be enabled")
}
if config.RateLimit.RequestsPerSecond != 100 {
t.Errorf("Expected RequestsPerSecond to be 100, got %d", config.RateLimit.RequestsPerSecond)
}
}
// TestSaveToFile tests saving configuration to files
func TestSaveToFile(t *testing.T) {
// Create a temporary directory for test files
tmpDir, err := os.MkdirTemp("", "config-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
loader := NewConfigLoader()
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "32-character-encryption-key-12345",
},
}
tests := []struct {
name string
filename string
wantErr bool
}{
{
name: "save as JSON",
filename: "config.json",
wantErr: false,
},
{
name: "save as YAML",
filename: "config.yaml",
wantErr: false,
},
{
name: "save as YML",
filename: "config.yml",
wantErr: false,
},
{
name: "unsupported extension",
filename: "config.txt",
wantErr: true,
},
{
name: "path traversal attempt",
filename: "../../../etc/config.json",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filePath := filepath.Join(tmpDir, tt.filename)
err := loader.SaveToFile(config, filePath)
if tt.wantErr {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
// Verify file was created with correct permissions
info, err := os.Stat(filePath)
if err != nil {
t.Errorf("Failed to stat saved file: %v", err)
return
}
// Check file permissions (should be 0600)
mode := info.Mode().Perm()
if mode != 0600 {
t.Errorf("Expected file permissions 0600, got %o", mode)
}
// Verify content can be read back
data, err := os.ReadFile(filePath)
if err != nil {
t.Errorf("Failed to read saved file: %v", err)
return
}
// Verify secrets are redacted
content := string(data)
if strings.Contains(content, "secret") && !strings.Contains(content, "[REDACTED]") {
t.Error("Secrets should be redacted in saved file")
}
})
}
}
// TestLoadFile tests loading configuration from files
func TestLoadFile(t *testing.T) {
// Create a temporary directory for test files
tmpDir, err := os.MkdirTemp("", "config-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)
// Test data - using old config format since unified config is not enabled by default
jsonConfig := `{
"providerURL": "https://auth.example.com",
"clientID": "test-client",
"clientSecret": "secret",
"sessionEncryptionKey": "32-character-encryption-key-12345"
}`
yamlConfig := `
providerurl: https://auth.example.com
clientid: test-client
clientsecret: secret
sessionencryptionkey: 32-character-encryption-key-12345
`
tests := []struct {
name string
filename string
content string
wantErr bool
}{
{
name: "load JSON config",
filename: "config.json",
content: jsonConfig,
wantErr: false,
},
{
name: "load YAML config",
filename: "config.yaml",
content: yamlConfig,
wantErr: false,
},
{
name: "path traversal attempt",
filename: "../../../etc/passwd",
content: "",
wantErr: true,
},
{
name: "non-existent file",
filename: "does-not-exist.json",
content: "",
wantErr: true,
},
}
loader := NewConfigLoader()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var filePath string
if tt.content != "" {
filePath = filepath.Join(tmpDir, tt.filename)
err := os.WriteFile(filePath, []byte(tt.content), 0600)
if err != nil {
t.Fatalf("Failed to write test file: %v", err)
return
}
} else {
filePath = tt.filename
}
config, err := loader.loadFile(filePath)
if tt.wantErr {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
if !os.IsNotExist(err) && !strings.Contains(err.Error(), "no such file") {
t.Errorf("Unexpected error: %v", err)
}
return
}
// Verify loaded config
if config == nil {
t.Error("Expected config to be loaded")
return
}
if config.Provider.IssuerURL != "https://auth.example.com" {
t.Errorf("Expected IssuerURL to be 'https://auth.example.com', got %s", config.Provider.IssuerURL)
}
if config.Provider.ClientID != "test-client" {
t.Errorf("Expected ClientID to be 'test-client', got %s", config.Provider.ClientID)
}
})
}
}
+169
View File
@@ -0,0 +1,169 @@
// Package config provides unified configuration management for the OIDC middleware
package config
import (
"encoding/json"
)
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
func (c UnifiedConfig) MarshalJSON() ([]byte, error) {
// Create an alias to avoid recursion
type Alias UnifiedConfig
// Create a copy with redacted sensitive fields
copy := (Alias)(c)
// Redact provider secrets
if copy.Provider.ClientSecret != "" {
copy.Provider.ClientSecret = REDACTED
}
// Redact session secrets
if copy.Session.Secret != "" {
copy.Session.Secret = REDACTED
}
if copy.Session.EncryptionKey != "" {
copy.Session.EncryptionKey = REDACTED
}
if copy.Session.SigningKey != "" {
copy.Session.SigningKey = REDACTED
}
// Redact Redis passwords
if copy.Redis.Password != "" {
copy.Redis.Password = REDACTED
}
if copy.Redis.SentinelPassword != "" {
copy.Redis.SentinelPassword = REDACTED
}
return json.Marshal(copy)
}
// MarshalJSON for ProviderConfig to redact sensitive fields
func (p ProviderConfig) MarshalJSON() ([]byte, error) {
type Alias ProviderConfig
copy := (Alias)(p)
if copy.ClientSecret != "" {
copy.ClientSecret = REDACTED
}
return json.Marshal(copy)
}
// MarshalJSON for SessionConfig to redact sensitive fields
func (s SessionConfig) MarshalJSON() ([]byte, error) {
type Alias SessionConfig
copy := (Alias)(s)
if copy.Secret != "" {
copy.Secret = REDACTED
}
if copy.EncryptionKey != "" {
copy.EncryptionKey = REDACTED
}
if copy.SigningKey != "" {
copy.SigningKey = REDACTED
}
return json.Marshal(copy)
}
// MarshalJSON for RedisConfig to redact sensitive fields
func (r RedisConfig) MarshalJSON() ([]byte, error) {
type Alias RedisConfig
copy := (Alias)(r)
if copy.Password != "" {
copy.Password = REDACTED
}
if copy.SentinelPassword != "" {
copy.SentinelPassword = REDACTED
}
return json.Marshal(copy)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
func (c UnifiedConfig) MarshalYAML() (interface{}, error) {
// Create an alias to avoid recursion
type Alias UnifiedConfig
// Create a copy with redacted sensitive fields
copy := (Alias)(c)
// Redact provider secrets
if copy.Provider.ClientSecret != "" {
copy.Provider.ClientSecret = REDACTED
}
// Redact session secrets
if copy.Session.Secret != "" {
copy.Session.Secret = REDACTED
}
if copy.Session.EncryptionKey != "" {
copy.Session.EncryptionKey = REDACTED
}
if copy.Session.SigningKey != "" {
copy.Session.SigningKey = REDACTED
}
// Redact Redis passwords
if copy.Redis.Password != "" {
copy.Redis.Password = REDACTED
}
if copy.Redis.SentinelPassword != "" {
copy.Redis.SentinelPassword = REDACTED
}
return copy, nil
}
// MarshalYAML for ProviderConfig to redact sensitive fields
func (p ProviderConfig) MarshalYAML() (interface{}, error) {
type Alias ProviderConfig
copy := (Alias)(p)
if copy.ClientSecret != "" {
copy.ClientSecret = REDACTED
}
return copy, nil
}
// MarshalYAML for SessionConfig to redact sensitive fields
func (s SessionConfig) MarshalYAML() (interface{}, error) {
type Alias SessionConfig
copy := (Alias)(s)
if copy.Secret != "" {
copy.Secret = REDACTED
}
if copy.EncryptionKey != "" {
copy.EncryptionKey = REDACTED
}
if copy.SigningKey != "" {
copy.SigningKey = REDACTED
}
return copy, nil
}
// MarshalYAML for RedisConfig to redact sensitive fields
func (r RedisConfig) MarshalYAML() (interface{}, error) {
type Alias RedisConfig
copy := (Alias)(r)
if copy.Password != "" {
copy.Password = REDACTED
}
if copy.SentinelPassword != "" {
copy.SentinelPassword = REDACTED
}
return copy, nil
}
+407
View File
@@ -0,0 +1,407 @@
// Package config provides configuration migration from old to new format
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/compat"
"github.com/lukaszraczylo/traefikoidc/internal/features"
"gopkg.in/yaml.v3"
)
// ConfigVersion represents the version of a configuration format
type ConfigVersion string
const (
// VersionLegacy represents the original config format
VersionLegacy ConfigVersion = "legacy"
// VersionUnified represents the new unified config format
VersionUnified ConfigVersion = "unified"
// CurrentVersion is the current config version
CurrentVersion ConfigVersion = VersionUnified
)
// ConfigMigrator handles migration between config versions
type ConfigMigrator struct {
compatLayer *compat.CompatibilityLayer
migrations map[ConfigVersion]MigrationFunc
}
// MigrationFunc defines a function that migrates configuration
type MigrationFunc func(data map[string]interface{}) (*UnifiedConfig, error)
// NewConfigMigrator creates a new configuration migrator
func NewConfigMigrator() *ConfigMigrator {
m := &ConfigMigrator{
compatLayer: compat.GetLayer(),
migrations: make(map[ConfigVersion]MigrationFunc),
}
// Register migration functions
m.migrations[VersionLegacy] = m.migrateLegacyToUnified
return m
}
// DetectVersion detects the version of a configuration
func (m *ConfigMigrator) DetectVersion(data []byte) ConfigVersion {
var testMap map[string]interface{}
// Try JSON first
if err := json.Unmarshal(data, &testMap); err != nil {
// Try YAML
if err := yaml.Unmarshal(data, &testMap); err != nil {
return VersionLegacy // Default to legacy if can't parse
}
}
// Check for unified config markers
if _, hasProvider := testMap["provider"]; hasProvider {
if _, hasSession := testMap["session"]; hasSession {
return VersionUnified
}
}
// Check for legacy config markers
if _, hasProviderURL := testMap["providerUrl"]; hasProviderURL {
return VersionLegacy
}
if _, hasProviderURL := testMap["ProviderURL"]; hasProviderURL {
return VersionLegacy
}
return VersionLegacy
}
// Migrate migrates configuration data to the current version
func (m *ConfigMigrator) Migrate(data []byte) (*UnifiedConfig, []string, error) {
warnings := []string{}
// Detect version
version := m.DetectVersion(data)
// If already current version, just unmarshal
if version == CurrentVersion {
var config UnifiedConfig
if err := json.Unmarshal(data, &config); err != nil {
// Try YAML
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, warnings, fmt.Errorf("failed to unmarshal unified config: %w", err)
}
}
return &config, warnings, nil
}
// Parse to generic map
var configMap map[string]interface{}
if err := json.Unmarshal(data, &configMap); err != nil {
// Try YAML
if err := yaml.Unmarshal(data, &configMap); err != nil {
return nil, warnings, fmt.Errorf("failed to unmarshal config: %w", err)
}
}
// Apply migration
migrationFunc, exists := m.migrations[version]
if !exists {
return nil, warnings, fmt.Errorf("no migration path from version %s", version)
}
config, err := migrationFunc(configMap)
if err != nil {
return nil, warnings, fmt.Errorf("migration failed: %w", err)
}
// Collect any deprecation warnings
for key := range configMap {
if warning, deprecated := m.compatLayer.CheckDeprecation(key); deprecated {
warnings = append(warnings, warning)
}
}
return config, warnings, nil
}
// migrateLegacyToUnified migrates legacy config to unified format
func (m *ConfigMigrator) migrateLegacyToUnified(data map[string]interface{}) (*UnifiedConfig, error) {
config := NewUnifiedConfig()
// Use compatibility layer for field mapping
migratedMap, warnings := m.compatLayer.MigrateMap(data)
// Log warnings
for _, warning := range warnings {
// In production, these would be logged
_ = warning
}
// Map provider configuration
if provider, ok := getNestedMap(migratedMap, "Provider"); ok {
_ = mapToStruct(provider, &config.Provider)
} else {
// Direct field mapping for legacy format
config.Provider.IssuerURL = getStringValue(data, "providerUrl", "ProviderURL")
config.Provider.ClientID = getStringValue(data, "clientId", "ClientID")
config.Provider.ClientSecret = getStringValue(data, "clientSecret", "ClientSecret")
config.Provider.RedirectURL = getStringValue(data, "callbackUrl", "CallbackURL")
config.Provider.LogoutURL = getStringValue(data, "logoutUrl", "LogoutURL")
config.Provider.PostLogoutRedirectURI = getStringValue(data, "postLogoutRedirectUri", "PostLogoutRedirectURI")
if scopes := getArrayValue(data, "scopes", "Scopes"); scopes != nil {
config.Provider.Scopes = scopes
}
config.Provider.OverrideScopes = getBoolValue(data, "overrideScopes", "OverrideScopes")
}
// Map session configuration
if session, ok := getNestedMap(migratedMap, "Session"); ok {
_ = mapToStruct(session, &config.Session)
} else {
config.Session.EncryptionKey = getStringValue(data, "sessionEncryptionKey", "SessionEncryptionKey")
config.Session.Domain = getStringValue(data, "cookieDomain", "CookieDomain")
}
// Map security configuration
if security, ok := getNestedMap(migratedMap, "Security"); ok {
_ = mapToStruct(security, &config.Security)
} else {
config.Security.ForceHTTPS = getBoolValue(data, "forceHttps", "ForceHTTPS")
config.Security.EnablePKCE = getBoolValue(data, "enablePkce", "EnablePKCE")
if users := getArrayValue(data, "allowedUsers", "AllowedUsers"); users != nil {
config.Security.AllowedUsers = users
}
if domains := getArrayValue(data, "allowedUserDomains", "AllowedUserDomains"); domains != nil {
config.Security.AllowedUserDomains = domains
}
if roles := getArrayValue(data, "allowedRolesAndGroups", "AllowedRolesAndGroups"); roles != nil {
config.Security.AllowedRolesAndGroups = roles
}
if excluded := getArrayValue(data, "excludedUrls", "ExcludedURLs"); excluded != nil {
config.Security.ExcludedURLs = excluded
}
// Handle security headers
if headers := data["securityHeaders"]; headers != nil {
// Security headers might be in old format
_ = mapToStruct(headers, &config.Security.Headers)
}
}
// Map rate limiting
if rateLimit := getIntValue(data, "rateLimit", "RateLimit"); rateLimit > 0 {
config.RateLimit.Enabled = true
config.RateLimit.RequestsPerSecond = rateLimit
config.RateLimit.Burst = rateLimit * 2 // Default burst to 2x rate
}
// Map token configuration
if refreshGrace := getIntValue(data, "refreshGracePeriodSeconds", "RefreshGracePeriodSeconds"); refreshGrace > 0 {
config.Token.RefreshGracePeriod = time.Duration(refreshGrace) * time.Second
}
// Map logging
config.Logging.Level = strings.ToLower(getStringValue(data, "logLevel", "LogLevel"))
if config.Logging.Level == "" {
config.Logging.Level = "info"
}
// Map custom headers
if headers := data["headers"]; headers != nil {
if headerList, ok := headers.([]interface{}); ok {
config.Middleware.CustomHeaders = make(map[string]string)
for _, h := range headerList {
if headerMap, ok := h.(map[string]interface{}); ok {
name := getStringFromInterface(headerMap["name"])
value := getStringFromInterface(headerMap["value"])
if name != "" {
config.Middleware.CustomHeaders[name] = value
}
}
}
}
}
// Store original data for reference
config.Legacy = data
return config, nil
}
// MigrateFile migrates a configuration file
func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(filePath)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", filePath)
}
// Ensure the path is within expected directories
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", filePath, err)
}
// Read the file with validated path
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
config, warnings, err := m.Migrate(data)
if err != nil {
return nil, err
}
// Log warnings
for _, warning := range warnings {
fmt.Printf("Migration Warning: %s\n", warning)
}
return config, nil
}
// AutoMigrate automatically migrates config based on feature flags
func AutoMigrate(data interface{}) (*UnifiedConfig, error) {
if !features.IsUnifiedConfigEnabled() {
// Feature not enabled, return nil
return nil, nil
}
migrator := NewConfigMigrator()
// Handle different input types
switch v := data.(type) {
case []byte:
config, _, err := migrator.Migrate(v)
return config, err
case string:
config, _, err := migrator.Migrate([]byte(v))
return config, err
case *Config:
// Convert old config to unified
return FromOldConfig(v), nil
case *UnifiedConfig:
// Already unified
return v, nil
case map[string]interface{}:
// Convert map to JSON then migrate
jsonData, err := json.Marshal(v)
if err != nil {
return nil, err
}
config, _, err := migrator.Migrate(jsonData)
return config, err
default:
return nil, fmt.Errorf("unsupported config type: %T", v)
}
}
// Helper functions
func getNestedMap(m map[string]interface{}, key string) (map[string]interface{}, bool) {
if val, exists := m[key]; exists {
if mapped, ok := val.(map[string]interface{}); ok {
return mapped, true
}
}
return nil, false
}
func getStringValue(m map[string]interface{}, keys ...string) string {
for _, key := range keys {
if val, exists := m[key]; exists {
return getStringFromInterface(val)
}
}
return ""
}
func getStringFromInterface(val interface{}) string {
if val == nil {
return ""
}
switch v := val.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprintf("%v", v)
}
}
func getBoolValue(m map[string]interface{}, keys ...string) bool {
for _, key := range keys {
if val, exists := m[key]; exists {
if b, ok := val.(bool); ok {
return b
}
// Try string conversion
if s, ok := val.(string); ok {
return strings.ToLower(s) == "true"
}
}
}
return false
}
func getIntValue(m map[string]interface{}, keys ...string) int {
for _, key := range keys {
if val, exists := m[key]; exists {
switch v := val.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case string:
// Try to parse
var i int
if _, err := fmt.Sscanf(v, "%d", &i); err != nil {
// If parsing fails, return default
return 0
}
return i
}
}
}
return 0
}
func getArrayValue(m map[string]interface{}, keys ...string) []string {
for _, key := range keys {
if val, exists := m[key]; exists {
if arr, ok := val.([]interface{}); ok {
result := make([]string, 0, len(arr))
for _, item := range arr {
result = append(result, getStringFromInterface(item))
}
return result
}
if strArr, ok := val.([]string); ok {
return strArr
}
}
}
return nil
}
func mapToStruct(m interface{}, target interface{}) error {
// Simple mapping using JSON as intermediate
data, err := json.Marshal(m)
if err != nil {
return err
}
return json.Unmarshal(data, target)
}
+286
View File
@@ -0,0 +1,286 @@
// Package config provides unified configuration management for the OIDC middleware
package config
import (
"time"
)
// UnifiedConfig is the master configuration structure consolidating all config aspects
// This replaces 45 duplicate config structs across the codebase
type UnifiedConfig struct {
// Core Configuration
Provider ProviderConfig `json:"provider" yaml:"provider"`
Session SessionConfig `json:"session" yaml:"session"`
Token TokenConfig `json:"token" yaml:"token"`
Redis RedisConfig `json:"redis" yaml:"redis"`
Security SecurityConfig `json:"security" yaml:"security"`
// Middleware Configuration
Middleware MiddlewareConfig `json:"middleware" yaml:"middleware"`
Cache CacheConfig `json:"cache" yaml:"cache"`
RateLimit RateLimitConfig `json:"rateLimit" yaml:"rateLimit"`
// Operational Configuration
Logging LoggingConfig `json:"logging" yaml:"logging"`
Metrics MetricsConfig `json:"metrics" yaml:"metrics"`
Health HealthConfig `json:"health" yaml:"health"`
// Advanced Configuration
Transport TransportConfig `json:"transport" yaml:"transport"`
Pool PoolConfig `json:"pool" yaml:"pool"`
Circuit CircuitConfig `json:"circuit" yaml:"circuit"`
// Compatibility field for migration
Legacy map[string]interface{} `json:"-" yaml:"-"`
}
// ProviderConfig contains OIDC provider settings
type ProviderConfig struct {
IssuerURL string `json:"issuerURL" yaml:"issuerURL"`
ClientID string `json:"clientID" yaml:"clientID"`
ClientSecret string `json:"clientSecret" yaml:"clientSecret"`
RedirectURL string `json:"redirectURL" yaml:"redirectURL"`
LogoutURL string `json:"logoutURL" yaml:"logoutURL"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI" yaml:"postLogoutRedirectURI"`
Scopes []string `json:"scopes" yaml:"scopes"`
OverrideScopes bool `json:"overrideScopes" yaml:"overrideScopes"`
CustomClaims map[string]string `json:"customClaims" yaml:"customClaims"`
JWKCachePeriod time.Duration `json:"jwkCachePeriod" yaml:"jwkCachePeriod"`
MetadataCacheTTL time.Duration `json:"metadataCacheTTL" yaml:"metadataCacheTTL"`
Discovery bool `json:"discovery" yaml:"discovery"`
// Provider-specific endpoints
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty" yaml:"authorizationEndpoint,omitempty"`
TokenEndpoint string `json:"tokenEndpoint,omitempty" yaml:"tokenEndpoint,omitempty"`
UserInfoEndpoint string `json:"userInfoEndpoint,omitempty" yaml:"userInfoEndpoint,omitempty"`
JWKSEndpoint string `json:"jwksEndpoint,omitempty" yaml:"jwksEndpoint,omitempty"`
IntrospectEndpoint string `json:"introspectEndpoint,omitempty" yaml:"introspectEndpoint,omitempty"`
RevocationEndpoint string `json:"revocationEndpoint,omitempty" yaml:"revocationEndpoint,omitempty"`
}
// SessionConfig contains session management settings
type SessionConfig struct {
Name string `json:"name" yaml:"name"`
MaxAge int `json:"maxAge" yaml:"maxAge"`
Secret string `json:"secret" yaml:"secret"`
EncryptionKey string `json:"encryptionKey" yaml:"encryptionKey"`
SigningKey string `json:"signingKey" yaml:"signingKey"`
ChunkSize int `json:"chunkSize" yaml:"chunkSize"`
MaxChunks int `json:"maxChunks" yaml:"maxChunks"`
// Cookie settings
Domain string `json:"domain" yaml:"domain"`
Path string `json:"path" yaml:"path"`
Secure bool `json:"secure" yaml:"secure"`
HttpOnly bool `json:"httpOnly" yaml:"httpOnly"`
SameSite string `json:"sameSite" yaml:"sameSite"`
// Storage settings
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis", "cookie"
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
}
// TokenConfig contains token handling settings
type TokenConfig struct {
AccessTokenTTL time.Duration `json:"accessTokenTTL" yaml:"accessTokenTTL"`
RefreshTokenTTL time.Duration `json:"refreshTokenTTL" yaml:"refreshTokenTTL"`
RefreshGracePeriod time.Duration `json:"refreshGracePeriod" yaml:"refreshGracePeriod"`
ValidationMode string `json:"validationMode" yaml:"validationMode"` // "jwt", "introspect", "hybrid"
IntrospectURL string `json:"introspectURL" yaml:"introspectURL"`
// Token caching
CacheEnabled bool `json:"cacheEnabled" yaml:"cacheEnabled"`
CacheTTL time.Duration `json:"cacheTTL" yaml:"cacheTTL"`
CacheNegativeTTL time.Duration `json:"cacheNegativeTTL" yaml:"cacheNegativeTTL"`
// Token validation
ValidateSignature bool `json:"validateSignature" yaml:"validateSignature"`
ValidateExpiry bool `json:"validateExpiry" yaml:"validateExpiry"`
ValidateAudience bool `json:"validateAudience" yaml:"validateAudience"`
ValidateIssuer bool `json:"validateIssuer" yaml:"validateIssuer"`
RequiredClaims []string `json:"requiredClaims" yaml:"requiredClaims"`
ClockSkew time.Duration `json:"clockSkew" yaml:"clockSkew"`
}
// SecurityConfig contains security-related settings
type SecurityConfig struct {
ForceHTTPS bool `json:"forceHTTPS" yaml:"forceHTTPS"`
EnablePKCE bool `json:"enablePKCE" yaml:"enablePKCE"`
AllowedUsers []string `json:"allowedUsers" yaml:"allowedUsers"`
AllowedUserDomains []string `json:"allowedUserDomains" yaml:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups" yaml:"allowedRolesAndGroups"`
ExcludedURLs []string `json:"excludedURLs" yaml:"excludedURLs"`
Headers *SecurityHeadersConfig `json:"headers" yaml:"headers"`
// CSRF protection
CSRFProtection bool `json:"csrfProtection" yaml:"csrfProtection"`
CSRFTokenName string `json:"csrfTokenName" yaml:"csrfTokenName"`
CSRFTokenTTL time.Duration `json:"csrfTokenTTL" yaml:"csrfTokenTTL"`
// Additional security
MaxLoginAttempts int `json:"maxLoginAttempts" yaml:"maxLoginAttempts"`
LockoutDuration time.Duration `json:"lockoutDuration" yaml:"lockoutDuration"`
RequireMFA bool `json:"requireMFA" yaml:"requireMFA"`
}
// MiddlewareConfig contains middleware-specific settings
type MiddlewareConfig struct {
Priority int `json:"priority" yaml:"priority"`
SkipPaths []string `json:"skipPaths" yaml:"skipPaths"`
RequirePaths []string `json:"requirePaths" yaml:"requirePaths"`
PassthroughMode bool `json:"passthroughMode" yaml:"passthroughMode"`
// Request handling
MaxRequestSize int64 `json:"maxRequestSize" yaml:"maxRequestSize"`
RequestTimeout time.Duration `json:"requestTimeout" yaml:"requestTimeout"`
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
// Response handling
CustomHeaders map[string]string `json:"customHeaders" yaml:"customHeaders"`
RemoveHeaders []string `json:"removeHeaders" yaml:"removeHeaders"`
}
// CacheConfig contains cache configuration
type CacheConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Type string `json:"type" yaml:"type"` // "memory", "redis", "hybrid"
DefaultTTL time.Duration `json:"defaultTTL" yaml:"defaultTTL"`
MaxEntries int `json:"maxEntries" yaml:"maxEntries"`
MaxEntrySize int64 `json:"maxEntrySize" yaml:"maxEntrySize"`
EvictionPolicy string `json:"evictionPolicy" yaml:"evictionPolicy"` // "lru", "lfu", "fifo"
// Memory cache settings
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
// Distributed cache settings
Namespace string `json:"namespace" yaml:"namespace"`
Compression bool `json:"compression" yaml:"compression"`
Serialization string `json:"serialization" yaml:"serialization"` // "json", "msgpack", "protobuf"
}
// RateLimitConfig contains rate limiting configuration
type RateLimitConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
RequestsPerSecond int `json:"requestsPerSecond" yaml:"requestsPerSecond"`
Burst int `json:"burst" yaml:"burst"`
// Rate limit storage
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis"
WindowDuration time.Duration `json:"windowDuration" yaml:"windowDuration"`
// Rate limit keys
KeyType string `json:"keyType" yaml:"keyType"` // "ip", "user", "token", "custom"
CustomKeyFunc string `json:"customKeyFunc" yaml:"customKeyFunc"`
// Whitelisting
WhitelistIPs []string `json:"whitelistIPs" yaml:"whitelistIPs"`
WhitelistUsers []string `json:"whitelistUsers" yaml:"whitelistUsers"`
}
// LoggingConfig contains logging configuration
type LoggingConfig struct {
Level string `json:"level" yaml:"level"` // "debug", "info", "warn", "error"
Format string `json:"format" yaml:"format"` // "json", "text", "structured"
Output string `json:"output" yaml:"output"` // "stdout", "stderr", "file"
FilePath string `json:"filePath" yaml:"filePath"`
// Log filtering
FilterSensitive bool `json:"filterSensitive" yaml:"filterSensitive"`
MaskFields []string `json:"maskFields" yaml:"maskFields"`
// Performance
BufferSize int `json:"bufferSize" yaml:"bufferSize"`
FlushInterval time.Duration `json:"flushInterval" yaml:"flushInterval"`
// Audit logging
AuditEnabled bool `json:"auditEnabled" yaml:"auditEnabled"`
AuditEvents []string `json:"auditEvents" yaml:"auditEvents"`
}
// MetricsConfig contains metrics collection configuration
type MetricsConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Provider string `json:"provider" yaml:"provider"` // "prometheus", "statsd", "otlp"
Endpoint string `json:"endpoint" yaml:"endpoint"`
Namespace string `json:"namespace" yaml:"namespace"`
Subsystem string `json:"subsystem" yaml:"subsystem"`
// Collection settings
CollectInterval time.Duration `json:"collectInterval" yaml:"collectInterval"`
Histograms bool `json:"histograms" yaml:"histograms"`
// Custom labels
Labels map[string]string `json:"labels" yaml:"labels"`
}
// HealthConfig contains health check configuration
type HealthConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Path string `json:"path" yaml:"path"`
CheckInterval time.Duration `json:"checkInterval" yaml:"checkInterval"`
Timeout time.Duration `json:"timeout" yaml:"timeout"`
// Checks to perform
CheckProvider bool `json:"checkProvider" yaml:"checkProvider"`
CheckRedis bool `json:"checkRedis" yaml:"checkRedis"`
CheckCache bool `json:"checkCache" yaml:"checkCache"`
// Thresholds
MaxLatency time.Duration `json:"maxLatency" yaml:"maxLatency"`
MinMemory int64 `json:"minMemory" yaml:"minMemory"`
}
// TransportConfig contains HTTP transport configuration
type TransportConfig struct {
MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"`
MaxIdleConnsPerHost int `json:"maxIdleConnsPerHost" yaml:"maxIdleConnsPerHost"`
MaxConnsPerHost int `json:"maxConnsPerHost" yaml:"maxConnsPerHost"`
IdleConnTimeout time.Duration `json:"idleConnTimeout" yaml:"idleConnTimeout"`
TLSHandshakeTimeout time.Duration `json:"tlsHandshakeTimeout" yaml:"tlsHandshakeTimeout"`
ExpectContinueTimeout time.Duration `json:"expectContinueTimeout" yaml:"expectContinueTimeout"`
ResponseHeaderTimeout time.Duration `json:"responseHeaderTimeout" yaml:"responseHeaderTimeout"`
DisableKeepAlives bool `json:"disableKeepAlives" yaml:"disableKeepAlives"`
DisableCompression bool `json:"disableCompression" yaml:"disableCompression"`
// TLS configuration
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify" yaml:"tlsInsecureSkipVerify"`
TLSMinVersion string `json:"tlsMinVersion" yaml:"tlsMinVersion"`
TLSCipherSuites []string `json:"tlsCipherSuites" yaml:"tlsCipherSuites"`
// Proxy settings
ProxyURL string `json:"proxyURL" yaml:"proxyURL"`
NoProxy []string `json:"noProxy" yaml:"noProxy"`
}
// PoolConfig contains connection pool configuration
type PoolConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Size int `json:"size" yaml:"size"`
MinSize int `json:"minSize" yaml:"minSize"`
MaxSize int `json:"maxSize" yaml:"maxSize"`
MaxAge time.Duration `json:"maxAge" yaml:"maxAge"`
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
WaitTimeout time.Duration `json:"waitTimeout" yaml:"waitTimeout"`
// Health checking
HealthCheckInterval time.Duration `json:"healthCheckInterval" yaml:"healthCheckInterval"`
MaxRetries int `json:"maxRetries" yaml:"maxRetries"`
}
// CircuitConfig contains circuit breaker configuration
type CircuitConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
MaxRequests uint32 `json:"maxRequests" yaml:"maxRequests"`
Interval time.Duration `json:"interval" yaml:"interval"`
Timeout time.Duration `json:"timeout" yaml:"timeout"`
ConsecutiveFailures uint32 `json:"consecutiveFailures" yaml:"consecutiveFailures"`
FailureRatio float64 `json:"failureRatio" yaml:"failureRatio"`
// Circuit states
OnOpen string `json:"onOpen" yaml:"onOpen"` // "reject", "fallback", "passthrough"
OnHalfOpen string `json:"onHalfOpen" yaml:"onHalfOpen"`
// Monitoring
MetricsEnabled bool `json:"metricsEnabled" yaml:"metricsEnabled"`
LogStateChanges bool `json:"logStateChanges" yaml:"logStateChanges"`
}
+263
View File
@@ -0,0 +1,263 @@
//go:build !yaegi
package config
import (
"encoding/json"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
// TestUnifiedConfigJSONMarshalling tests JSON marshalling with secret redaction
func TestUnifiedConfigJSONMarshalling(t *testing.T) {
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "super-secret-value",
},
Session: SessionConfig{
Secret: "session-secret",
EncryptionKey: "32-character-encryption-key-here",
SigningKey: "signing-key-secret",
},
Redis: RedisConfig{
Password: "redis-password",
SentinelPassword: "sentinel-password",
},
}
// Marshal to JSON
jsonBytes, err := json.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal config to JSON: %v", err)
}
jsonStr := string(jsonBytes)
// Verify secrets are redacted
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
t.Error("ClientSecret should be redacted in JSON output")
}
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
t.Error("Session.Secret should be redacted in JSON output")
}
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
t.Error("Session.EncryptionKey should be redacted in JSON output")
}
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
t.Error("Session.SigningKey should be redacted in JSON output")
}
if !contains(jsonStr, `"password":"[REDACTED]"`) {
t.Error("Redis.Password should be redacted in JSON output")
}
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
t.Error("Redis.SentinelPassword should be redacted in JSON output")
}
// Verify non-secret fields are preserved
if !contains(jsonStr, `"issuerURL":"https://auth.example.com"`) {
t.Error("IssuerURL should be preserved in JSON output")
}
if !contains(jsonStr, `"clientID":"test-client"`) {
t.Error("ClientID should be preserved in JSON output")
}
}
// TestUnifiedConfigYAMLMarshalling tests YAML marshalling with secret redaction
func TestUnifiedConfigYAMLMarshalling(t *testing.T) {
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "super-secret-value",
},
Session: SessionConfig{
Secret: "session-secret",
EncryptionKey: "32-character-encryption-key-here",
SigningKey: "signing-key-secret",
},
Redis: RedisConfig{
Password: "redis-password",
SentinelPassword: "sentinel-password",
},
}
// Marshal to YAML
yamlBytes, err := yaml.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal config to YAML: %v", err)
}
yamlStr := string(yamlBytes)
// Verify secrets are redacted
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
t.Error("ClientSecret should be redacted in YAML output")
}
if !contains(yamlStr, "secret: '[REDACTED]'") {
t.Error("Session.Secret should be redacted in YAML output")
}
if !contains(yamlStr, "encryptionKey: '[REDACTED]'") {
t.Error("Session.EncryptionKey should be redacted in YAML output")
}
if !contains(yamlStr, "signingKey: '[REDACTED]'") {
t.Error("Session.SigningKey should be redacted in YAML output")
}
if !contains(yamlStr, "password: '[REDACTED]'") {
t.Error("Redis.Password should be redacted in YAML output")
}
if !contains(yamlStr, "sentinelPassword: '[REDACTED]'") {
t.Error("Redis.SentinelPassword should be redacted in YAML output")
}
// Verify non-secret fields are preserved
if !contains(yamlStr, "issuerURL: https://auth.example.com") {
t.Error("IssuerURL should be preserved in YAML output")
}
if !contains(yamlStr, "clientID: test-client") {
t.Error("ClientID should be preserved in YAML output")
}
}
// TestProviderConfigMarshalling tests individual struct marshalling
func TestProviderConfigMarshalling(t *testing.T) {
provider := ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "super-secret-value",
}
// Test JSON marshalling
jsonBytes, err := json.Marshal(provider)
if err != nil {
t.Fatalf("Failed to marshal ProviderConfig to JSON: %v", err)
}
jsonStr := string(jsonBytes)
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
t.Error("ClientSecret should be redacted in JSON output")
}
if !contains(jsonStr, `"clientID":"test-client"`) {
t.Error("ClientID should be preserved in JSON output")
}
// Test YAML marshalling
yamlBytes, err := yaml.Marshal(provider)
if err != nil {
t.Fatalf("Failed to marshal ProviderConfig to YAML: %v", err)
}
yamlStr := string(yamlBytes)
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
t.Error("ClientSecret should be redacted in YAML output")
}
if !contains(yamlStr, "clientID: test-client") {
t.Error("ClientID should be preserved in YAML output")
}
}
// TestSessionConfigMarshalling tests session config marshalling
func TestSessionConfigMarshalling(t *testing.T) {
session := SessionConfig{
Name: "session-cookie",
Secret: "session-secret",
EncryptionKey: "32-character-encryption-key-here",
SigningKey: "signing-key-secret",
Domain: "example.com",
Secure: true,
}
// Test JSON marshalling
jsonBytes, err := json.Marshal(session)
if err != nil {
t.Fatalf("Failed to marshal SessionConfig to JSON: %v", err)
}
jsonStr := string(jsonBytes)
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
t.Error("Secret should be redacted in JSON output")
}
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
t.Error("EncryptionKey should be redacted in JSON output")
}
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
t.Error("SigningKey should be redacted in JSON output")
}
if !contains(jsonStr, `"name":"session-cookie"`) {
t.Error("Name should be preserved in JSON output")
}
if !contains(jsonStr, `"domain":"example.com"`) {
t.Error("Domain should be preserved in JSON output")
}
}
// TestRedisConfigMarshalling tests Redis config marshalling
func TestRedisConfigMarshalling(t *testing.T) {
redis := RedisConfig{
Enabled: true,
Mode: RedisModeCluster,
Password: "redis-password",
SentinelPassword: "sentinel-password",
Addr: "localhost:6379",
DB: 1,
}
// Test JSON marshalling
jsonBytes, err := json.Marshal(redis)
if err != nil {
t.Fatalf("Failed to marshal RedisConfig to JSON: %v", err)
}
jsonStr := string(jsonBytes)
if !contains(jsonStr, `"password":"[REDACTED]"`) {
t.Error("Password should be redacted in JSON output")
}
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
t.Error("SentinelPassword should be redacted in JSON output")
}
if !contains(jsonStr, `"addr":"localhost:6379"`) {
t.Error("Addr should be preserved in JSON output")
}
if !contains(jsonStr, `"db":1`) {
t.Error("DB should be preserved in JSON output")
}
}
// TestEmptySecretsNotRedacted tests that empty secrets are not shown as redacted
func TestEmptySecretsNotRedacted(t *testing.T) {
config := &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "", // Empty secret
},
Session: SessionConfig{
Secret: "", // Empty secret
EncryptionKey: "", // Empty secret
},
Redis: RedisConfig{
Password: "", // Empty secret
},
}
// Marshal to JSON
jsonBytes, err := json.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal config to JSON: %v", err)
}
jsonStr := string(jsonBytes)
// Verify empty secrets are not shown as redacted
if contains(jsonStr, "[REDACTED]") {
t.Error("Empty secrets should not be shown as [REDACTED]")
}
}
// Helper function to check if string contains substring
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}
+652
View File
@@ -0,0 +1,652 @@
// Package config provides validation for unified configuration
package config
import (
"fmt"
"net/url"
"regexp"
"strings"
"time"
)
// ValidationError represents a configuration validation error
type ValidationError struct {
Field string
Message string
Value interface{}
}
// Error implements the error interface
func (e *ValidationError) Error() string {
if e.Value != nil {
return fmt.Sprintf("config validation error: %s: %s (value: %v)", e.Field, e.Message, e.Value)
}
return fmt.Sprintf("config validation error: %s: %s", e.Field, e.Message)
}
// ValidationErrors represents multiple validation errors
type ValidationErrors []ValidationError
// Error implements the error interface
func (e ValidationErrors) Error() string {
if len(e) == 0 {
return ""
}
var messages []string
for _, err := range e {
messages = append(messages, err.Error())
}
return strings.Join(messages, "; ")
}
// Validate performs comprehensive validation on the unified configuration
func (c *UnifiedConfig) Validate() error {
var errors ValidationErrors
// Validate Provider configuration
if err := c.validateProvider(); err != nil {
errors = append(errors, err...)
}
// Validate Session configuration
if err := c.validateSession(); err != nil {
errors = append(errors, err...)
}
// Validate Token configuration
if err := c.validateToken(); err != nil {
errors = append(errors, err...)
}
// Validate Redis configuration (uses existing validation)
if err := c.Redis.Validate(); err != nil {
errors = append(errors, ValidationError{
Field: "Redis",
Message: err.Error(),
})
}
// Validate Security configuration
if err := c.validateSecurity(); err != nil {
errors = append(errors, err...)
}
// Validate Middleware configuration
if err := c.validateMiddleware(); err != nil {
errors = append(errors, err...)
}
// Validate Cache configuration
if err := c.validateCache(); err != nil {
errors = append(errors, err...)
}
// Validate RateLimit configuration
if err := c.validateRateLimit(); err != nil {
errors = append(errors, err...)
}
// Validate Logging configuration
if err := c.validateLogging(); err != nil {
errors = append(errors, err...)
}
// Validate Metrics configuration
if err := c.validateMetrics(); err != nil {
errors = append(errors, err...)
}
// Validate Transport configuration
if err := c.validateTransport(); err != nil {
errors = append(errors, err...)
}
// Validate Circuit configuration
if err := c.validateCircuit(); err != nil {
errors = append(errors, err...)
}
if len(errors) > 0 {
return errors
}
return nil
}
// validateProvider validates provider configuration
func (c *UnifiedConfig) validateProvider() ValidationErrors {
var errors ValidationErrors
// IssuerURL is required and must be a valid URL
if c.Provider.IssuerURL == "" {
errors = append(errors, ValidationError{
Field: "Provider.IssuerURL",
Message: "issuer URL is required",
})
} else if _, err := url.Parse(c.Provider.IssuerURL); err != nil {
errors = append(errors, ValidationError{
Field: "Provider.IssuerURL",
Message: "invalid issuer URL",
Value: c.Provider.IssuerURL,
})
}
// ClientID is required
if c.Provider.ClientID == "" {
errors = append(errors, ValidationError{
Field: "Provider.ClientID",
Message: "client ID is required",
})
}
// ClientSecret is required (except for public clients with PKCE)
if c.Provider.ClientSecret == "" && !c.Security.EnablePKCE {
errors = append(errors, ValidationError{
Field: "Provider.ClientSecret",
Message: "client secret is required (or enable PKCE for public clients)",
})
}
// RedirectURL must be valid if provided
if c.Provider.RedirectURL != "" {
if _, err := url.Parse(c.Provider.RedirectURL); err != nil {
errors = append(errors, ValidationError{
Field: "Provider.RedirectURL",
Message: "invalid redirect URL",
Value: c.Provider.RedirectURL,
})
}
}
// Scopes must include 'openid' for OIDC
hasOpenID := false
for _, scope := range c.Provider.Scopes {
if scope == "openid" {
hasOpenID = true
break
}
}
if !hasOpenID && !c.Provider.OverrideScopes {
errors = append(errors, ValidationError{
Field: "Provider.Scopes",
Message: "scopes must include 'openid' for OIDC",
Value: c.Provider.Scopes,
})
}
// JWK cache period must be positive
if c.Provider.JWKCachePeriod < 0 {
errors = append(errors, ValidationError{
Field: "Provider.JWKCachePeriod",
Message: "JWK cache period must be positive",
Value: c.Provider.JWKCachePeriod,
})
}
return errors
}
// validateSession validates session configuration
func (c *UnifiedConfig) validateSession() ValidationErrors {
var errors ValidationErrors
// Session name must not be empty
if c.Session.Name == "" {
errors = append(errors, ValidationError{
Field: "Session.Name",
Message: "session name is required",
})
}
// Session secret or encryption key is required
if c.Session.Secret == "" && c.Session.EncryptionKey == "" {
errors = append(errors, ValidationError{
Field: "Session",
Message: "either session secret or encryption key is required",
})
}
// Encryption key must be at least 32 bytes for security
if c.Session.EncryptionKey != "" && len(c.Session.EncryptionKey) < 32 {
errors = append(errors, ValidationError{
Field: "Session.EncryptionKey",
Message: "encryption key must be at least 32 characters for proper security",
Value: len(c.Session.EncryptionKey),
})
}
// ChunkSize must be reasonable (between 1KB and 10KB)
if c.Session.ChunkSize < 1000 || c.Session.ChunkSize > 10000 {
errors = append(errors, ValidationError{
Field: "Session.ChunkSize",
Message: "chunk size must be between 1000 and 10000 bytes",
Value: c.Session.ChunkSize,
})
}
// MaxChunks must be reasonable (between 1 and 100)
if c.Session.MaxChunks < 1 || c.Session.MaxChunks > 100 {
errors = append(errors, ValidationError{
Field: "Session.MaxChunks",
Message: "max chunks must be between 1 and 100",
Value: c.Session.MaxChunks,
})
}
// SameSite must be valid
validSameSite := map[string]bool{
"": true,
"Lax": true,
"Strict": true,
"None": true,
}
if !validSameSite[c.Session.SameSite] {
errors = append(errors, ValidationError{
Field: "Session.SameSite",
Message: "invalid SameSite value (must be Lax, Strict, or None)",
Value: c.Session.SameSite,
})
}
// StorageType must be valid
validStorage := map[string]bool{
"memory": true,
"redis": true,
"cookie": true,
}
if !validStorage[c.Session.StorageType] {
errors = append(errors, ValidationError{
Field: "Session.StorageType",
Message: "invalid storage type (must be memory, redis, or cookie)",
Value: c.Session.StorageType,
})
}
return errors
}
// validateToken validates token configuration
func (c *UnifiedConfig) validateToken() ValidationErrors {
var errors ValidationErrors
// Token TTLs must be positive
if c.Token.AccessTokenTTL <= 0 {
errors = append(errors, ValidationError{
Field: "Token.AccessTokenTTL",
Message: "access token TTL must be positive",
Value: c.Token.AccessTokenTTL,
})
}
if c.Token.RefreshTokenTTL <= 0 {
errors = append(errors, ValidationError{
Field: "Token.RefreshTokenTTL",
Message: "refresh token TTL must be positive",
Value: c.Token.RefreshTokenTTL,
})
}
// Validation mode must be valid
validModes := map[string]bool{
"jwt": true,
"introspect": true,
"hybrid": true,
}
if !validModes[c.Token.ValidationMode] {
errors = append(errors, ValidationError{
Field: "Token.ValidationMode",
Message: "invalid validation mode (must be jwt, introspect, or hybrid)",
Value: c.Token.ValidationMode,
})
}
// Introspect URL required for introspect or hybrid mode
if (c.Token.ValidationMode == "introspect" || c.Token.ValidationMode == "hybrid") && c.Token.IntrospectURL == "" {
errors = append(errors, ValidationError{
Field: "Token.IntrospectURL",
Message: "introspect URL is required for introspect or hybrid validation mode",
})
}
// Clock skew must be reasonable (0 to 10 minutes)
if c.Token.ClockSkew < 0 || c.Token.ClockSkew > 10*time.Minute {
errors = append(errors, ValidationError{
Field: "Token.ClockSkew",
Message: "clock skew must be between 0 and 10 minutes",
Value: c.Token.ClockSkew,
})
}
return errors
}
// validateSecurity validates security configuration
func (c *UnifiedConfig) validateSecurity() ValidationErrors {
var errors ValidationErrors
// Validate allowed user domains are valid domains
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$`)
for _, domain := range c.Security.AllowedUserDomains {
if !domainRegex.MatchString(domain) {
errors = append(errors, ValidationError{
Field: "Security.AllowedUserDomains",
Message: "invalid domain format",
Value: domain,
})
}
}
// Max login attempts must be reasonable
if c.Security.MaxLoginAttempts < 0 || c.Security.MaxLoginAttempts > 100 {
errors = append(errors, ValidationError{
Field: "Security.MaxLoginAttempts",
Message: "max login attempts must be between 0 and 100",
Value: c.Security.MaxLoginAttempts,
})
}
// Lockout duration must be reasonable
if c.Security.LockoutDuration < 0 || c.Security.LockoutDuration > 24*time.Hour {
errors = append(errors, ValidationError{
Field: "Security.LockoutDuration",
Message: "lockout duration must be between 0 and 24 hours",
Value: c.Security.LockoutDuration,
})
}
return errors
}
// validateMiddleware validates middleware configuration
func (c *UnifiedConfig) validateMiddleware() ValidationErrors {
var errors ValidationErrors
// Max request size must be reasonable (1KB to 100MB)
if c.Middleware.MaxRequestSize < 1024 || c.Middleware.MaxRequestSize > 100*1024*1024 {
errors = append(errors, ValidationError{
Field: "Middleware.MaxRequestSize",
Message: "max request size must be between 1KB and 100MB",
Value: c.Middleware.MaxRequestSize,
})
}
// Request timeout must be reasonable
if c.Middleware.RequestTimeout < time.Second || c.Middleware.RequestTimeout > 5*time.Minute {
errors = append(errors, ValidationError{
Field: "Middleware.RequestTimeout",
Message: "request timeout must be between 1 second and 5 minutes",
Value: c.Middleware.RequestTimeout,
})
}
return errors
}
// validateCache validates cache configuration
func (c *UnifiedConfig) validateCache() ValidationErrors {
var errors ValidationErrors
if !c.Cache.Enabled {
return errors
}
// Cache type must be valid
validTypes := map[string]bool{
"memory": true,
"redis": true,
"hybrid": true,
}
if !validTypes[c.Cache.Type] {
errors = append(errors, ValidationError{
Field: "Cache.Type",
Message: "invalid cache type (must be memory, redis, or hybrid)",
Value: c.Cache.Type,
})
}
// Max entries must be reasonable
if c.Cache.MaxEntries < 10 || c.Cache.MaxEntries > 1000000 {
errors = append(errors, ValidationError{
Field: "Cache.MaxEntries",
Message: "max entries must be between 10 and 1000000",
Value: c.Cache.MaxEntries,
})
}
// Eviction policy must be valid
validEviction := map[string]bool{
"lru": true,
"lfu": true,
"fifo": true,
}
if !validEviction[c.Cache.EvictionPolicy] {
errors = append(errors, ValidationError{
Field: "Cache.EvictionPolicy",
Message: "invalid eviction policy (must be lru, lfu, or fifo)",
Value: c.Cache.EvictionPolicy,
})
}
return errors
}
// validateRateLimit validates rate limiting configuration
func (c *UnifiedConfig) validateRateLimit() ValidationErrors {
var errors ValidationErrors
if !c.RateLimit.Enabled {
return errors
}
// Requests per second must be reasonable
if c.RateLimit.RequestsPerSecond < 1 || c.RateLimit.RequestsPerSecond > 10000 {
errors = append(errors, ValidationError{
Field: "RateLimit.RequestsPerSecond",
Message: "requests per second must be between 1 and 10000",
Value: c.RateLimit.RequestsPerSecond,
})
}
// Burst must be at least as large as requests per second
if c.RateLimit.Burst < c.RateLimit.RequestsPerSecond {
errors = append(errors, ValidationError{
Field: "RateLimit.Burst",
Message: "burst must be at least as large as requests per second",
Value: c.RateLimit.Burst,
})
}
// Key type must be valid
validKeyTypes := map[string]bool{
"ip": true,
"user": true,
"token": true,
"custom": true,
}
if !validKeyTypes[c.RateLimit.KeyType] {
errors = append(errors, ValidationError{
Field: "RateLimit.KeyType",
Message: "invalid key type (must be ip, user, token, or custom)",
Value: c.RateLimit.KeyType,
})
}
return errors
}
// validateLogging validates logging configuration
func (c *UnifiedConfig) validateLogging() ValidationErrors {
var errors ValidationErrors
// Log level must be valid
validLevels := map[string]bool{
"debug": true,
"info": true,
"warn": true,
"error": true,
}
if !validLevels[c.Logging.Level] {
errors = append(errors, ValidationError{
Field: "Logging.Level",
Message: "invalid log level (must be debug, info, warn, or error)",
Value: c.Logging.Level,
})
}
// Format must be valid
validFormats := map[string]bool{
"json": true,
"text": true,
"structured": true,
}
if !validFormats[c.Logging.Format] {
errors = append(errors, ValidationError{
Field: "Logging.Format",
Message: "invalid log format (must be json, text, or structured)",
Value: c.Logging.Format,
})
}
// Output must be valid
validOutputs := map[string]bool{
"stdout": true,
"stderr": true,
"file": true,
}
if !validOutputs[c.Logging.Output] {
errors = append(errors, ValidationError{
Field: "Logging.Output",
Message: "invalid log output (must be stdout, stderr, or file)",
Value: c.Logging.Output,
})
}
// File path required if output is file
if c.Logging.Output == "file" && c.Logging.FilePath == "" {
errors = append(errors, ValidationError{
Field: "Logging.FilePath",
Message: "file path is required when output is 'file'",
})
}
return errors
}
// validateMetrics validates metrics configuration
func (c *UnifiedConfig) validateMetrics() ValidationErrors {
var errors ValidationErrors
if !c.Metrics.Enabled {
return errors
}
// Provider must be valid
validProviders := map[string]bool{
"prometheus": true,
"statsd": true,
"otlp": true,
}
if !validProviders[c.Metrics.Provider] {
errors = append(errors, ValidationError{
Field: "Metrics.Provider",
Message: "invalid metrics provider (must be prometheus, statsd, or otlp)",
Value: c.Metrics.Provider,
})
}
// Endpoint required for some providers
if (c.Metrics.Provider == "statsd" || c.Metrics.Provider == "otlp") && c.Metrics.Endpoint == "" {
errors = append(errors, ValidationError{
Field: "Metrics.Endpoint",
Message: fmt.Sprintf("endpoint is required for %s provider", c.Metrics.Provider),
})
}
return errors
}
// validateTransport validates transport configuration
func (c *UnifiedConfig) validateTransport() ValidationErrors {
var errors ValidationErrors
// Max connections must be reasonable
if c.Transport.MaxIdleConns < 0 || c.Transport.MaxIdleConns > 10000 {
errors = append(errors, ValidationError{
Field: "Transport.MaxIdleConns",
Message: "max idle connections must be between 0 and 10000",
Value: c.Transport.MaxIdleConns,
})
}
// TLS min version must be valid
validTLSVersions := map[string]bool{
"TLS1.0": true,
"TLS1.1": true,
"TLS1.2": true,
"TLS1.3": true,
}
if c.Transport.TLSMinVersion != "" && !validTLSVersions[c.Transport.TLSMinVersion] {
errors = append(errors, ValidationError{
Field: "Transport.TLSMinVersion",
Message: "invalid TLS min version (must be TLS1.0, TLS1.1, TLS1.2, or TLS1.3)",
Value: c.Transport.TLSMinVersion,
})
}
// Proxy URL must be valid if provided
if c.Transport.ProxyURL != "" {
if _, err := url.Parse(c.Transport.ProxyURL); err != nil {
errors = append(errors, ValidationError{
Field: "Transport.ProxyURL",
Message: "invalid proxy URL",
Value: c.Transport.ProxyURL,
})
}
}
return errors
}
// validateCircuit validates circuit breaker configuration
func (c *UnifiedConfig) validateCircuit() ValidationErrors {
var errors ValidationErrors
if !c.Circuit.Enabled {
return errors
}
// Consecutive failures must be reasonable
if c.Circuit.ConsecutiveFailures < 1 || c.Circuit.ConsecutiveFailures > 100 {
errors = append(errors, ValidationError{
Field: "Circuit.ConsecutiveFailures",
Message: "consecutive failures must be between 1 and 100",
Value: c.Circuit.ConsecutiveFailures,
})
}
// Failure ratio must be between 0 and 1
if c.Circuit.FailureRatio < 0 || c.Circuit.FailureRatio > 1 {
errors = append(errors, ValidationError{
Field: "Circuit.FailureRatio",
Message: "failure ratio must be between 0 and 1",
Value: c.Circuit.FailureRatio,
})
}
// OnOpen action must be valid
validActions := map[string]bool{
"reject": true,
"fallback": true,
"passthrough": true,
}
if !validActions[c.Circuit.OnOpen] {
errors = append(errors, ValidationError{
Field: "Circuit.OnOpen",
Message: "invalid OnOpen action (must be reject, fallback, or passthrough)",
Value: c.Circuit.OnOpen,
})
}
return errors
}
+340
View File
@@ -0,0 +1,340 @@
//go:build !yaegi
package config
import (
"strings"
"testing"
"time"
)
// TestValidateUnifiedConfig tests the validation of UnifiedConfig
func TestValidateUnifiedConfig(t *testing.T) {
tests := []struct {
name string
config *UnifiedConfig
expectError bool
errorField string
}{
{
name: "valid config with minimum requirements",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
Scopes: []string{"openid", "profile", "email"},
},
Session: SessionConfig{
Name: "oidc_session",
EncryptionKey: "this-is-a-32-character-key-12345",
ChunkSize: 4000,
MaxChunks: 5,
StorageType: "cookie",
},
Token: TokenConfig{
AccessTokenTTL: time.Hour,
RefreshTokenTTL: 24 * time.Hour,
ValidationMode: "jwt",
},
Middleware: MiddlewareConfig{
MaxRequestSize: 10 * 1024 * 1024,
RequestTimeout: 30 * time.Second,
},
Logging: LoggingConfig{
Level: "info",
Format: "json",
Output: "stdout",
},
},
expectError: false,
},
{
name: "missing provider URL",
config: &UnifiedConfig{
Provider: ProviderConfig{
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
},
expectError: true,
errorField: "Provider.IssuerURL",
},
{
name: "missing client ID",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
},
expectError: true,
errorField: "Provider.ClientID",
},
{
name: "encryption key too short",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "too-short",
},
},
expectError: true,
errorField: "Session.EncryptionKey",
},
{
name: "invalid chunk size",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
ChunkSize: 500, // Too small
},
},
expectError: true,
errorField: "Session.ChunkSize",
},
{
name: "invalid max chunks",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
ChunkSize: 4000,
MaxChunks: 0, // Too small
},
},
expectError: true,
errorField: "Session.MaxChunks",
},
{
name: "invalid TLS min version",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
Transport: TransportConfig{
TLSMinVersion: "1.0", // Too old
},
},
expectError: true,
errorField: "Transport.TLSMinVersion",
},
{
name: "invalid circuit breaker failure ratio",
config: &UnifiedConfig{
Provider: ProviderConfig{
IssuerURL: "https://auth.example.com",
ClientID: "test-client",
ClientSecret: "secret",
},
Session: SessionConfig{
EncryptionKey: "this-is-a-32-character-key-12345",
},
Circuit: CircuitConfig{
Enabled: true,
FailureRatio: 1.5, // Too high
},
},
expectError: true,
errorField: "Circuit.FailureRatio",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.expectError {
if err == nil {
t.Errorf("Expected validation error for field %s, but got none", tt.errorField)
} else if validationErrs, ok := err.(ValidationErrors); ok {
found := false
for _, e := range validationErrs {
if e.Field == tt.errorField {
found = true
break
}
}
if !found {
t.Errorf("Expected validation error for field %s, but got errors for: %v",
tt.errorField, validationErrs)
}
}
} else {
if err != nil {
t.Errorf("Expected no validation error, but got: %v", err)
}
}
})
}
}
// TestValidationErrorMessage tests validation error formatting
func TestValidationErrorMessage(t *testing.T) {
errs := ValidationErrors{
{
Field: "Provider.IssuerURL",
Message: "is required",
Value: nil,
},
{
Field: "Session.EncryptionKey",
Message: "must be at least 32 characters",
Value: 16,
},
}
errMsg := errs.Error()
if !strings.Contains(errMsg, "Provider.IssuerURL") {
t.Error("Error message should contain field name Provider.IssuerURL")
}
if !strings.Contains(errMsg, "is required") {
t.Error("Error message should contain 'is required'")
}
if !strings.Contains(errMsg, "Session.EncryptionKey") {
t.Error("Error message should contain field name Session.EncryptionKey")
}
if !strings.Contains(errMsg, "must be at least 32 characters") {
t.Error("Error message should contain 'must be at least 32 characters'")
}
}
// TestValidateRedisConfig tests Redis configuration validation
func TestValidateRedisConfig(t *testing.T) {
tests := []struct {
name string
config *RedisConfig
expectError bool
errorMsg string
}{
{
name: "valid standalone config",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeStandalone,
Addr: "localhost:6379",
},
expectError: false,
},
{
name: "missing address for standalone",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeStandalone,
Addr: "",
},
expectError: true,
errorMsg: "Redis address is required",
},
{
name: "valid cluster config",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeCluster,
ClusterAddrs: []string{"localhost:7000", "localhost:7001"},
},
expectError: false,
},
{
name: "missing cluster addresses",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeCluster,
ClusterAddrs: []string{},
},
expectError: true,
errorMsg: "cluster address is required",
},
{
name: "valid sentinel config",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeSentinel,
MasterName: "mymaster",
SentinelAddrs: []string{"localhost:26379"},
},
expectError: false,
},
{
name: "missing master name for sentinel",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeSentinel,
MasterName: "",
SentinelAddrs: []string{"localhost:26379"},
},
expectError: true,
errorMsg: "Master name is required",
},
{
name: "missing sentinel addresses",
config: &RedisConfig{
Enabled: true,
Mode: RedisModeSentinel,
MasterName: "mymaster",
SentinelAddrs: []string{},
},
expectError: true,
errorMsg: "sentinel address is required",
},
{
name: "disabled redis needs no validation",
config: &RedisConfig{
Enabled: false,
},
expectError: false,
},
{
name: "invalid redis mode",
config: &RedisConfig{
Enabled: true,
Mode: "invalid-mode",
},
expectError: true,
errorMsg: "Invalid Redis mode",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.expectError {
if err == nil {
t.Errorf("Expected validation error containing '%s', but got none", tt.errorMsg)
} else if !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error message to contain '%s', but got: %v", tt.errorMsg, err)
}
} else {
if err != nil {
t.Errorf("Expected no validation error, but got: %v", err)
}
}
})
}
}
+116
View File
@@ -0,0 +1,116 @@
package traefikoidc
import (
"encoding/json"
)
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
result := make(map[string]interface{})
// Copy public fields
result["providerURL"] = c.ProviderURL
result["clientID"] = c.ClientID
result["callbackURL"] = c.CallbackURL
result["logoutURL"] = c.LogoutURL
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
result["scopes"] = c.Scopes
result["forceHTTPS"] = c.ForceHTTPS
result["logLevel"] = c.LogLevel
result["rateLimit"] = c.RateLimit
result["excludedURLs"] = c.ExcludedURLs
result["allowedUserDomains"] = c.AllowedUserDomains
result["allowedUsers"] = c.AllowedUsers
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
// Redact sensitive fields
result["clientSecret"] = REDACTED
result["sessionEncryptionKey"] = REDACTED
// Handle Redis config
if c.Redis != nil {
redisMap := make(map[string]interface{})
redisMap["enabled"] = c.Redis.Enabled
redisMap["address"] = c.Redis.Address
redisMap["password"] = REDACTED
redisMap["db"] = c.Redis.DB
redisMap["poolSize"] = c.Redis.PoolSize
redisMap["cacheMode"] = c.Redis.CacheMode
result["redis"] = redisMap
}
return result, nil
}
// MarshalJSON for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalJSON() ([]byte, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return json.Marshal(result)
}
// MarshalYAML for RedisConfig to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (r RedisConfig) MarshalYAML() (interface{}, error) {
result := make(map[string]interface{})
result["enabled"] = r.Enabled
result["address"] = r.Address
result["password"] = REDACTED
result["db"] = r.DB
result["poolSize"] = r.PoolSize
result["cacheMode"] = r.CacheMode
return result, nil
}
+407
View File
@@ -0,0 +1,407 @@
// Package cleanup provides background task management and cleanup functionality.
package cleanup
import (
"context"
"fmt"
"runtime"
"sync"
"sync/atomic"
"time"
)
// Logger defines the logging interface
type Logger interface {
Logf(format string, args ...interface{})
ErrorLogf(format string, args ...interface{})
DebugLogf(format string, args ...interface{})
}
// BackgroundTask represents a recurring background task
type BackgroundTask struct {
name string
interval time.Duration
taskFunc func()
ticker *time.Ticker
stopChan chan bool
isRunning int32
logger Logger
waitGroup *sync.WaitGroup
lastRun time.Time
runCount int64
errorCount int64
mu sync.RWMutex
ctx context.Context
cancelFunc context.CancelFunc
}
// NewBackgroundTask creates a new background task
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger Logger, wg ...*sync.WaitGroup) *BackgroundTask {
var waitGroup *sync.WaitGroup
if len(wg) > 0 && wg[0] != nil {
waitGroup = wg[0]
}
ctx, cancel := context.WithCancel(context.Background())
return &BackgroundTask{
name: name,
interval: interval,
taskFunc: taskFunc,
stopChan: make(chan bool, 1),
isRunning: 0,
logger: logger,
waitGroup: waitGroup,
ctx: ctx,
cancelFunc: cancel,
}
}
// Start begins executing the background task
func (bt *BackgroundTask) Start() {
if !atomic.CompareAndSwapInt32(&bt.isRunning, 0, 1) {
if bt.logger != nil {
bt.logger.Logf("Background task %s is already running", bt.name)
}
return
}
bt.ticker = time.NewTicker(bt.interval)
if bt.waitGroup != nil {
bt.waitGroup.Add(1)
}
go bt.run()
if bt.logger != nil {
bt.logger.Logf("Started background task: %s (interval: %v)", bt.name, bt.interval)
}
}
// Stop stops the background task
func (bt *BackgroundTask) Stop() {
if !atomic.CompareAndSwapInt32(&bt.isRunning, 1, 0) {
if bt.logger != nil {
bt.logger.Logf("Background task %s is not running", bt.name)
}
return
}
// Cancel context
if bt.cancelFunc != nil {
bt.cancelFunc()
}
// Stop ticker
if bt.ticker != nil {
bt.ticker.Stop()
}
// Send stop signal
select {
case bt.stopChan <- true:
case <-time.After(5 * time.Second):
if bt.logger != nil {
bt.logger.ErrorLogf("Timeout stopping background task: %s", bt.name)
}
}
if bt.logger != nil {
bt.logger.Logf("Stopped background task: %s", bt.name)
}
}
// run is the main loop for the background task
func (bt *BackgroundTask) run() {
defer func() {
if bt.waitGroup != nil {
bt.waitGroup.Done()
}
if r := recover(); r != nil {
atomic.AddInt64(&bt.errorCount, 1)
if bt.logger != nil {
bt.logger.ErrorLogf("Background task %s panicked: %v", bt.name, r)
}
}
}()
// Run task immediately on start
bt.executeTask()
for {
select {
case <-bt.ticker.C:
bt.executeTask()
case <-bt.stopChan:
return
case <-bt.ctx.Done():
return
}
}
}
// executeTask runs the task function with error handling
func (bt *BackgroundTask) executeTask() {
defer func() {
if r := recover(); r != nil {
atomic.AddInt64(&bt.errorCount, 1)
if bt.logger != nil {
bt.logger.ErrorLogf("Task %s panicked: %v", bt.name, r)
}
}
}()
bt.mu.Lock()
bt.lastRun = time.Now()
bt.mu.Unlock()
atomic.AddInt64(&bt.runCount, 1)
bt.taskFunc()
}
// GetStats returns statistics about the task
func (bt *BackgroundTask) GetStats() map[string]interface{} {
bt.mu.RLock()
lastRun := bt.lastRun
bt.mu.RUnlock()
return map[string]interface{}{
"name": bt.name,
"interval": bt.interval.String(),
"isRunning": atomic.LoadInt32(&bt.isRunning) == 1,
"lastRun": lastRun.Format(time.RFC3339),
"runCount": atomic.LoadInt64(&bt.runCount),
"errorCount": atomic.LoadInt64(&bt.errorCount),
}
}
// IsRunning returns whether the task is currently running
func (bt *BackgroundTask) IsRunning() bool {
return atomic.LoadInt32(&bt.isRunning) == 1
}
// TaskRegistry manages all background tasks
type TaskRegistry struct {
tasks map[string]*BackgroundTask
mu sync.RWMutex
logger Logger
maxTasks int
circuitBreaker *TaskCircuitBreaker
}
// globalTaskRegistry is the singleton task registry
var (
globalTaskRegistry *TaskRegistry
registryOnce sync.Once
registryMutex sync.Mutex
)
// GetGlobalTaskRegistry returns the global task registry singleton
func GetGlobalTaskRegistry() *TaskRegistry {
registryOnce.Do(func() {
globalTaskRegistry = &TaskRegistry{
tasks: make(map[string]*BackgroundTask),
maxTasks: 100, // Default maximum tasks
}
})
return globalTaskRegistry
}
// ResetGlobalTaskRegistry resets the global task registry (mainly for testing)
func ResetGlobalTaskRegistry() {
registryMutex.Lock()
defer registryMutex.Unlock()
if globalTaskRegistry != nil {
globalTaskRegistry.StopAllTasks()
globalTaskRegistry = nil
}
registryOnce = sync.Once{}
}
// NewTaskRegistry creates a new task registry
func NewTaskRegistry(logger Logger, maxTasks int) *TaskRegistry {
return &TaskRegistry{
tasks: make(map[string]*BackgroundTask),
logger: logger,
maxTasks: maxTasks,
circuitBreaker: NewTaskCircuitBreaker(5, 30*time.Second, logger),
}
}
// RegisterTask registers a new background task
func (tr *TaskRegistry) RegisterTask(name string, task *BackgroundTask) error {
if task == nil {
return fmt.Errorf("task cannot be nil")
}
tr.mu.Lock()
defer tr.mu.Unlock()
// Check if task already exists
if _, exists := tr.tasks[name]; exists {
return fmt.Errorf("task with name %s already exists", name)
}
// Check task limit
if len(tr.tasks) >= tr.maxTasks {
return fmt.Errorf("maximum number of tasks (%d) reached", tr.maxTasks)
}
// Check circuit breaker
if tr.circuitBreaker != nil {
if err := tr.circuitBreaker.CanCreateTask(name); err != nil {
return err
}
}
tr.tasks[name] = task
if tr.logger != nil {
tr.logger.Logf("Registered task: %s", name)
}
return nil
}
// UnregisterTask removes a task from the registry
func (tr *TaskRegistry) UnregisterTask(name string) {
tr.mu.Lock()
defer tr.mu.Unlock()
if task, exists := tr.tasks[name]; exists {
if task.IsRunning() {
task.Stop()
}
delete(tr.tasks, name)
if tr.logger != nil {
tr.logger.Logf("Unregistered task: %s", name)
}
}
}
// GetTask returns a task by name
func (tr *TaskRegistry) GetTask(name string) (*BackgroundTask, bool) {
tr.mu.RLock()
defer tr.mu.RUnlock()
task, exists := tr.tasks[name]
return task, exists
}
// StopAllTasks stops all registered tasks
func (tr *TaskRegistry) StopAllTasks() {
tr.mu.RLock()
tasks := make([]*BackgroundTask, 0, len(tr.tasks))
for _, task := range tr.tasks {
tasks = append(tasks, task)
}
tr.mu.RUnlock()
var wg sync.WaitGroup
for _, task := range tasks {
if task.IsRunning() {
wg.Add(1)
go func(t *BackgroundTask) {
defer wg.Done()
t.Stop()
}(task)
}
}
wg.Wait()
// Clear all tasks from the registry after stopping them
tr.mu.Lock()
tr.tasks = make(map[string]*BackgroundTask)
tr.mu.Unlock()
if tr.logger != nil {
tr.logger.Logf("Stopped all tasks")
}
}
// GetTaskCount returns the number of registered tasks
func (tr *TaskRegistry) GetTaskCount() int {
tr.mu.RLock()
defer tr.mu.RUnlock()
return len(tr.tasks)
}
// CreateSingletonTask creates or retrieves an existing task
func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
taskFunc func(), logger Logger, wg ...*sync.WaitGroup) (*BackgroundTask, error) {
// Check if task already exists
if existingTask, exists := tr.GetTask(name); exists {
if existingTask.IsRunning() {
if logger != nil {
logger.Logf("Task %s already exists and is running", name)
}
return existingTask, nil
}
// Task exists but not running, start it
existingTask.Start()
return existingTask, nil
}
// Create new task
task := NewBackgroundTask(name, interval, taskFunc, logger, wg...)
// Register task
if err := tr.RegisterTask(name, task); err != nil {
return nil, err
}
// Start task
task.Start()
return task, nil
}
// GetAllTasks returns all registered tasks
func (tr *TaskRegistry) GetAllTasks() map[string]*BackgroundTask {
tr.mu.RLock()
defer tr.mu.RUnlock()
tasks := make(map[string]*BackgroundTask)
for name, task := range tr.tasks {
tasks[name] = task
}
return tasks
}
// GetStats returns statistics for all tasks
func (tr *TaskRegistry) GetStats() map[string]interface{} {
tr.mu.RLock()
defer tr.mu.RUnlock()
stats := make(map[string]interface{})
stats["totalTasks"] = len(tr.tasks)
runningCount := 0
taskStats := make(map[string]interface{})
for name, task := range tr.tasks {
if task.IsRunning() {
runningCount++
}
taskStats[name] = task.GetStats()
}
stats["runningTasks"] = runningCount
stats["tasks"] = taskStats
// Add memory stats
var m runtime.MemStats
runtime.ReadMemStats(&m)
stats["memory"] = map[string]interface{}{
"alloc": m.Alloc,
"totalAlloc": m.TotalAlloc,
"sys": m.Sys,
"numGC": m.NumGC,
"goroutines": runtime.NumGoroutine(),
}
return stats
}
+449
View File
@@ -0,0 +1,449 @@
// Package cleanup provides background task management and cleanup functionality.
package cleanup
import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"time"
)
// TaskCircuitBreaker prevents task creation failures from cascading
type TaskCircuitBreaker struct {
failureThreshold int32
failureCount int32
lastFailureTime time.Time
timeout time.Duration
state int32 // 0: closed, 1: open
logger Logger
mu sync.RWMutex
taskFailures map[string]int32
}
// CircuitBreakerState represents the state of the circuit breaker
type CircuitBreakerState int32
const (
CircuitBreakerClosed CircuitBreakerState = iota
CircuitBreakerOpen
)
// NewTaskCircuitBreaker creates a new circuit breaker for task management
func NewTaskCircuitBreaker(failureThreshold int32, timeout time.Duration, logger Logger) *TaskCircuitBreaker {
return &TaskCircuitBreaker{
failureThreshold: failureThreshold,
timeout: timeout,
logger: logger,
taskFailures: make(map[string]int32),
}
}
// CanCreateTask checks if a new task can be created
func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
cb.mu.RLock()
defer cb.mu.RUnlock()
// Check circuit breaker state
if atomic.LoadInt32(&cb.state) == int32(CircuitBreakerOpen) {
// Check if timeout has elapsed
if time.Since(cb.lastFailureTime) < cb.timeout {
return fmt.Errorf("circuit breaker open: too many task failures")
}
// Reset circuit breaker
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
atomic.StoreInt32(&cb.failureCount, 0)
if cb.logger != nil {
cb.logger.Logf("Circuit breaker reset after timeout")
}
}
// Check task-specific failures
if failures, exists := cb.taskFailures[taskName]; exists {
if failures >= cb.failureThreshold {
return fmt.Errorf("task %s has too many failures (%d)", taskName, failures)
}
}
return nil
}
// OnTaskStart records that a task has started
func (cb *TaskCircuitBreaker) OnTaskStart(taskName string) {
// Currently just for tracking, could add rate limiting here
if cb.logger != nil {
cb.logger.DebugLogf("Task %s started", taskName)
}
}
// OnTaskComplete records that a task completed (success or failure)
func (cb *TaskCircuitBreaker) OnTaskComplete(taskName string) {
// Currently just for tracking
if cb.logger != nil {
cb.logger.DebugLogf("Task %s completed", taskName)
}
}
// OnTaskSuccess records a successful task execution
func (cb *TaskCircuitBreaker) OnTaskSuccess(taskName string) {
cb.mu.Lock()
defer cb.mu.Unlock()
// Reset task-specific failure count on success
delete(cb.taskFailures, taskName)
}
// OnTaskFailure records a task failure
func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
cb.mu.Lock()
defer cb.mu.Unlock()
// Increment task-specific failure count
cb.taskFailures[taskName]++
// Increment overall failure count
failures := atomic.AddInt32(&cb.failureCount, 1)
cb.lastFailureTime = time.Now()
if cb.logger != nil {
cb.logger.ErrorLogf("Task %s failed: %v (failure count: %d)", taskName, err, cb.taskFailures[taskName])
}
// Open circuit breaker if threshold reached
if failures >= cb.failureThreshold {
atomic.StoreInt32(&cb.state, int32(CircuitBreakerOpen))
if cb.logger != nil {
cb.logger.ErrorLogf("Circuit breaker opened due to %d failures", failures)
}
}
}
// Reset resets the circuit breaker
func (cb *TaskCircuitBreaker) Reset() {
cb.mu.Lock()
defer cb.mu.Unlock()
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
atomic.StoreInt32(&cb.failureCount, 0)
cb.taskFailures = make(map[string]int32)
cb.lastFailureTime = time.Time{}
if cb.logger != nil {
cb.logger.Logf("Circuit breaker reset")
}
}
// GetState returns the current state of the circuit breaker
func (cb *TaskCircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(atomic.LoadInt32(&cb.state))
}
// TaskMemoryMonitor monitors memory usage and can trigger cleanup
type TaskMemoryMonitor struct {
logger Logger
registry *TaskRegistry
memoryThreshold uint64
checkInterval time.Duration
isMonitoring int32
stopChan chan bool
lastCheck time.Time
mu sync.RWMutex
}
var (
globalMemoryMonitor *TaskMemoryMonitor
monitorOnce sync.Once
)
// GetGlobalTaskMemoryMonitor returns the global memory monitor singleton
func GetGlobalTaskMemoryMonitor(logger Logger) *TaskMemoryMonitor {
monitorOnce.Do(func() {
globalMemoryMonitor = NewTaskMemoryMonitor(logger, GetGlobalTaskRegistry())
})
return globalMemoryMonitor
}
// NewTaskMemoryMonitor creates a new memory monitor
func NewTaskMemoryMonitor(logger Logger, registry *TaskRegistry) *TaskMemoryMonitor {
return &TaskMemoryMonitor{
logger: logger,
registry: registry,
memoryThreshold: 1024 * 1024 * 1024, // 1GB default
checkInterval: 1 * time.Minute,
stopChan: make(chan bool, 1),
}
}
// SetMemoryThreshold sets the memory threshold for triggering cleanup
func (tmm *TaskMemoryMonitor) SetMemoryThreshold(bytes uint64) {
tmm.mu.Lock()
defer tmm.mu.Unlock()
tmm.memoryThreshold = bytes
}
// StartMonitoring starts the memory monitoring routine
func (tmm *TaskMemoryMonitor) StartMonitoring() {
if !atomic.CompareAndSwapInt32(&tmm.isMonitoring, 0, 1) {
if tmm.logger != nil {
tmm.logger.Logf("Memory monitor is already running")
}
return
}
go tmm.monitorLoop()
if tmm.logger != nil {
tmm.logger.Logf("Started memory monitoring (threshold: %d bytes, interval: %v)",
tmm.memoryThreshold, tmm.checkInterval)
}
}
// StopMonitoring stops the memory monitoring routine
func (tmm *TaskMemoryMonitor) StopMonitoring() {
if !atomic.CompareAndSwapInt32(&tmm.isMonitoring, 1, 0) {
if tmm.logger != nil {
tmm.logger.Logf("Memory monitor is not running")
}
return
}
select {
case tmm.stopChan <- true:
case <-time.After(5 * time.Second):
if tmm.logger != nil {
tmm.logger.ErrorLogf("Timeout stopping memory monitor")
}
}
if tmm.logger != nil {
tmm.logger.Logf("Stopped memory monitoring")
}
}
// monitorLoop is the main monitoring loop
func (tmm *TaskMemoryMonitor) monitorLoop() {
ticker := time.NewTicker(tmm.checkInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
tmm.checkMemory()
case <-tmm.stopChan:
return
}
}
}
// checkMemory checks current memory usage and triggers cleanup if needed
func (tmm *TaskMemoryMonitor) checkMemory() {
tmm.mu.Lock()
tmm.lastCheck = time.Now()
tmm.mu.Unlock()
var m runtime.MemStats
runtime.ReadMemStats(&m)
if tmm.logger != nil {
tmm.logger.DebugLogf("Memory check - Alloc: %d MB, Sys: %d MB, NumGC: %d",
m.Alloc/1024/1024, m.Sys/1024/1024, m.NumGC)
}
// Check if memory usage exceeds threshold
if m.Alloc > tmm.memoryThreshold {
if tmm.logger != nil {
tmm.logger.Logf("Memory usage (%d MB) exceeds threshold (%d MB), triggering cleanup",
m.Alloc/1024/1024, tmm.memoryThreshold/1024/1024)
}
// Trigger garbage collection
runtime.GC()
// Could also trigger task-specific cleanup here
tmm.triggerTaskCleanup()
}
}
// triggerTaskCleanup triggers cleanup operations on tasks
func (tmm *TaskMemoryMonitor) triggerTaskCleanup() {
if tmm.registry == nil {
return
}
// Get all tasks and potentially pause non-critical ones
tasks := tmm.registry.GetAllTasks()
for name, task := range tasks {
// Could implement task priority here
if tmm.logger != nil {
tmm.logger.DebugLogf("Checking task %s for cleanup opportunities", name)
}
// Tasks could implement a Cleanup() method
_ = task // Placeholder for future cleanup logic
}
}
// GetStats returns memory monitor statistics
func (tmm *TaskMemoryMonitor) GetStats() map[string]interface{} {
tmm.mu.RLock()
lastCheck := tmm.lastCheck
tmm.mu.RUnlock()
var m runtime.MemStats
runtime.ReadMemStats(&m)
return map[string]interface{}{
"isMonitoring": atomic.LoadInt32(&tmm.isMonitoring) == 1,
"lastCheck": lastCheck.Format(time.RFC3339),
"checkInterval": tmm.checkInterval.String(),
"memoryThreshold": tmm.memoryThreshold,
"currentMemory": map[string]interface{}{
"alloc": m.Alloc,
"totalAlloc": m.TotalAlloc,
"sys": m.Sys,
"mallocs": m.Mallocs,
"frees": m.Frees,
"numGC": m.NumGC,
"goroutines": runtime.NumGoroutine(),
},
}
}
// WorkerPool manages a pool of worker goroutines for task execution
type WorkerPool struct {
workers int
taskQueue chan func()
workerWg sync.WaitGroup
isRunning int32
logger Logger
stopChan chan bool
metrics WorkerPoolMetrics
}
// WorkerPoolMetrics tracks worker pool performance
type WorkerPoolMetrics struct {
tasksProcessed int64
tasksQueued int64
tasksFailed int64
avgProcessTime int64 // nanoseconds
}
// NewWorkerPool creates a new worker pool
func NewWorkerPool(workers int, queueSize int, logger Logger) *WorkerPool {
if workers <= 0 {
workers = runtime.NumCPU()
}
if queueSize <= 0 {
queueSize = workers * 10
}
return &WorkerPool{
workers: workers,
taskQueue: make(chan func(), queueSize),
stopChan: make(chan bool),
logger: logger,
}
}
// Start starts the worker pool
func (wp *WorkerPool) Start() {
if !atomic.CompareAndSwapInt32(&wp.isRunning, 0, 1) {
if wp.logger != nil {
wp.logger.Logf("Worker pool is already running")
}
return
}
for i := 0; i < wp.workers; i++ {
wp.workerWg.Add(1)
go wp.worker(i)
}
if wp.logger != nil {
wp.logger.Logf("Started worker pool with %d workers", wp.workers)
}
}
// Stop stops the worker pool
func (wp *WorkerPool) Stop() {
if !atomic.CompareAndSwapInt32(&wp.isRunning, 1, 0) {
if wp.logger != nil {
wp.logger.Logf("Worker pool is not running")
}
return
}
close(wp.stopChan)
close(wp.taskQueue)
wp.workerWg.Wait()
if wp.logger != nil {
wp.logger.Logf("Stopped worker pool")
}
}
// Submit submits a task to the worker pool
func (wp *WorkerPool) Submit(task func()) error {
if atomic.LoadInt32(&wp.isRunning) != 1 {
return fmt.Errorf("worker pool is not running")
}
select {
case wp.taskQueue <- task:
atomic.AddInt64(&wp.metrics.tasksQueued, 1)
return nil
default:
return fmt.Errorf("worker pool queue is full")
}
}
// worker is the main worker routine
func (wp *WorkerPool) worker(id int) {
defer wp.workerWg.Done()
for {
select {
case task, ok := <-wp.taskQueue:
if !ok {
return // Channel closed
}
wp.executeTask(task)
case <-wp.stopChan:
return
}
}
}
// executeTask executes a task with error handling
func (wp *WorkerPool) executeTask(task func()) {
startTime := time.Now()
defer func() {
if r := recover(); r != nil {
atomic.AddInt64(&wp.metrics.tasksFailed, 1)
if wp.logger != nil {
wp.logger.ErrorLogf("Worker pool task panicked: %v", r)
}
}
// Update average process time
duration := time.Since(startTime).Nanoseconds()
processed := atomic.AddInt64(&wp.metrics.tasksProcessed, 1)
currentAvg := atomic.LoadInt64(&wp.metrics.avgProcessTime)
newAvg := (currentAvg*(processed-1) + duration) / processed
atomic.StoreInt64(&wp.metrics.avgProcessTime, newAvg)
}()
task()
}
// GetMetrics returns worker pool metrics
func (wp *WorkerPool) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"workers": wp.workers,
"isRunning": atomic.LoadInt32(&wp.isRunning) == 1,
"queueSize": len(wp.taskQueue),
"queueCapacity": cap(wp.taskQueue),
"tasksProcessed": atomic.LoadInt64(&wp.metrics.tasksProcessed),
"tasksQueued": atomic.LoadInt64(&wp.metrics.tasksQueued),
"tasksFailed": atomic.LoadInt64(&wp.metrics.tasksFailed),
"avgProcessTime": time.Duration(atomic.LoadInt64(&wp.metrics.avgProcessTime)),
}
}
+320
View File
@@ -0,0 +1,320 @@
// Package compat provides backward compatibility layer during refactoring
package compat
import (
"fmt"
"reflect"
"sync"
)
// CompatibilityLayer provides backward compatibility during the migration
type CompatibilityLayer struct {
mappings map[string]string // old path -> new path
converters map[string]Converter
deprecations map[string]string // deprecated field -> warning message
mu sync.RWMutex
}
// Converter is a function that converts old value format to new format
type Converter func(oldValue interface{}) (newValue interface{}, err error)
// Global compatibility layer instance
var (
layer *CompatibilityLayer
layerOnce sync.Once
)
// GetLayer returns the global compatibility layer instance
func GetLayer() *CompatibilityLayer {
layerOnce.Do(func() {
layer = &CompatibilityLayer{
mappings: make(map[string]string),
converters: make(map[string]Converter),
deprecations: make(map[string]string),
}
layer.initialize()
})
return layer
}
// initialize sets up default compatibility mappings
func (c *CompatibilityLayer) initialize() {
// Configuration path mappings (old -> new)
c.RegisterMapping("ProviderURL", "Provider.IssuerURL")
c.RegisterMapping("ClientID", "Provider.ClientID")
c.RegisterMapping("ClientSecret", "Provider.ClientSecret")
c.RegisterMapping("CallbackURL", "Provider.RedirectURL")
c.RegisterMapping("LogoutURL", "Provider.LogoutURL")
c.RegisterMapping("SessionEncryptionKey", "Session.EncryptionKey")
c.RegisterMapping("Scopes", "Provider.Scopes")
c.RegisterMapping("RateLimit", "Middleware.RateLimit")
c.RegisterMapping("RefreshGracePeriodSeconds", "Token.RefreshGracePeriod")
// Redis configuration mappings
c.RegisterMapping("RedisAddr", "Redis.Addresses[0]")
c.RegisterMapping("RedisPassword", "Redis.Password")
c.RegisterMapping("RedisDB", "Redis.DB")
// Session configuration mappings
c.RegisterMapping("SessionName", "Session.Name")
c.RegisterMapping("SessionMaxAge", "Session.MaxAge")
c.RegisterMapping("SessionSecret", "Session.Secret")
c.RegisterMapping("SessionChunkSize", "Session.ChunkSize")
// Security configuration mappings
c.RegisterMapping("ForceHTTPS", "Security.ForceHTTPS")
c.RegisterMapping("EnablePKCE", "Security.EnablePKCE")
c.RegisterMapping("AllowedUsers", "Security.AllowedUsers")
c.RegisterMapping("AllowedUserDomains", "Security.AllowedUserDomains")
c.RegisterMapping("AllowedRolesAndGroups", "Security.AllowedRolesAndGroups")
c.RegisterMapping("ExcludedURLs", "Security.ExcludedURLs")
// Register converters for complex transformations
c.RegisterConverter("RefreshGracePeriodSeconds", func(oldValue interface{}) (interface{}, error) {
// Convert seconds (int) to duration string
if seconds, ok := oldValue.(int); ok {
return fmt.Sprintf("%ds", seconds), nil
}
return oldValue, nil
})
// Register deprecations
c.RegisterDeprecation("LogLevel", "LogLevel is deprecated, use Logging.Level instead")
c.RegisterDeprecation("HTTPClient", "HTTPClient is deprecated, configure via Transport settings")
}
// RegisterMapping registers a field mapping from old to new path
func (c *CompatibilityLayer) RegisterMapping(oldPath, newPath string) {
c.mu.Lock()
defer c.mu.Unlock()
c.mappings[oldPath] = newPath
}
// RegisterConverter registers a value converter for a field
func (c *CompatibilityLayer) RegisterConverter(field string, converter Converter) {
c.mu.Lock()
defer c.mu.Unlock()
c.converters[field] = converter
}
// RegisterDeprecation registers a deprecation warning for a field
func (c *CompatibilityLayer) RegisterDeprecation(field, message string) {
c.mu.Lock()
defer c.mu.Unlock()
c.deprecations[field] = message
}
// GetMapping returns the new path for an old configuration path
func (c *CompatibilityLayer) GetMapping(oldPath string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
newPath, exists := c.mappings[oldPath]
return newPath, exists
}
// Convert applies conversion logic to a value
func (c *CompatibilityLayer) Convert(field string, value interface{}) (interface{}, error) {
c.mu.RLock()
converter, exists := c.converters[field]
c.mu.RUnlock()
if !exists {
return value, nil
}
return converter(value)
}
// CheckDeprecation checks if a field is deprecated and returns warning message
func (c *CompatibilityLayer) CheckDeprecation(field string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
message, deprecated := c.deprecations[field]
return message, deprecated
}
// MigrateMap migrates an old configuration map to new structure
func (c *CompatibilityLayer) MigrateMap(oldConfig map[string]interface{}) (map[string]interface{}, []string) {
newConfig := make(map[string]interface{})
warnings := []string{}
for key, value := range oldConfig {
// Check for deprecation
if warning, deprecated := c.CheckDeprecation(key); deprecated {
warnings = append(warnings, warning)
}
// Get new path
newPath, hasMappming := c.GetMapping(key)
if !hasMappming {
// No mapping, use as-is
newConfig[key] = value
continue
}
// Apply converter if exists
convertedValue, err := c.Convert(key, value)
if err != nil {
warnings = append(warnings, fmt.Sprintf("Failed to convert %s: %v", key, err))
convertedValue = value
}
// Set value at new path
setNestedValue(newConfig, newPath, convertedValue)
}
return newConfig, warnings
}
// setNestedValue sets a value in a nested map structure using dot notation
func setNestedValue(m map[string]interface{}, path string, value interface{}) {
keys := splitPath(path)
if len(keys) == 0 {
return
}
current := m
for i := 0; i < len(keys)-1; i++ {
key := keys[i]
// Check if this key has array notation
if isArrayPath(key) {
// Handle array notation (e.g., "Addresses[0]")
continue // Skip array handling for now, will be handled in actual migration
}
if _, exists := current[key]; !exists {
current[key] = make(map[string]interface{})
}
// Ensure it's a map
if next, ok := current[key].(map[string]interface{}); ok {
current = next
} else {
// Can't traverse further, create new map
newMap := make(map[string]interface{})
current[key] = newMap
current = newMap
}
}
// Set the final value
finalKey := keys[len(keys)-1]
current[finalKey] = value
}
// splitPath splits a configuration path into segments
func splitPath(path string) []string {
segments := []string{}
current := ""
for i := 0; i < len(path); i++ {
if path[i] == '.' {
if current != "" {
segments = append(segments, current)
current = ""
}
} else {
current += string(path[i])
}
}
if current != "" {
segments = append(segments, current)
}
return segments
}
// isArrayPath checks if a path segment contains array notation
func isArrayPath(segment string) bool {
for _, char := range segment {
if char == '[' {
return true
}
}
return false
}
// ConfigAdapter provides an adapter interface for old code to work with new config
type ConfigAdapter struct {
newConfig interface{}
oldPaths map[string]func() interface{}
mu sync.RWMutex
}
// NewConfigAdapter creates a new configuration adapter
func NewConfigAdapter(newConfig interface{}) *ConfigAdapter {
adapter := &ConfigAdapter{
newConfig: newConfig,
oldPaths: make(map[string]func() interface{}),
}
return adapter
}
// RegisterGetter registers a getter function for an old path
func (a *ConfigAdapter) RegisterGetter(oldPath string, getter func() interface{}) {
a.mu.Lock()
defer a.mu.Unlock()
a.oldPaths[oldPath] = getter
}
// Get retrieves a value using old path notation
func (a *ConfigAdapter) Get(oldPath string) (interface{}, bool) {
a.mu.RLock()
getter, exists := a.oldPaths[oldPath]
a.mu.RUnlock()
if !exists {
// Try to get from new config using reflection
return a.getFromNewConfig(oldPath)
}
return getter(), true
}
// getFromNewConfig attempts to retrieve value from new config using reflection
func (a *ConfigAdapter) getFromNewConfig(path string) (interface{}, bool) {
// Check if there's a mapping for this path
compat := GetLayer()
if newPath, hasMappming := compat.GetMapping(path); hasMappming {
return a.getNestedField(newPath)
}
// Try direct access
return a.getNestedField(path)
}
// getNestedField retrieves a nested field value using reflection
func (a *ConfigAdapter) getNestedField(path string) (interface{}, bool) {
segments := splitPath(path)
if len(segments) == 0 {
return nil, false
}
v := reflect.ValueOf(a.newConfig)
// Dereference pointer if needed
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
for _, segment := range segments {
if v.Kind() != reflect.Struct {
return nil, false
}
field := v.FieldByName(segment)
if !field.IsValid() {
return nil, false
}
v = field
}
if v.IsValid() && v.CanInterface() {
return v.Interface(), true
}
return nil, false
}
+235
View File
@@ -0,0 +1,235 @@
// Package features provides feature flag management for safe rollback during refactoring
package features
import (
"os"
"strings"
"sync"
"sync/atomic"
)
// FeatureFlag represents a feature flag for controlling new functionality
type FeatureFlag struct {
name string
description string
enabled atomic.Bool
mu sync.RWMutex
callbacks []func(bool)
}
// FeatureManager manages all feature flags in the application
type FeatureManager struct {
flags map[string]*FeatureFlag
mu sync.RWMutex
}
var (
// Global feature manager instance
manager *FeatureManager
managerOnce sync.Once
)
// Feature flag names
const (
// UseUnifiedConfig enables the new unified configuration system
UseUnifiedConfig = "USE_UNIFIED_CONFIG"
// UseNewFileStructure enables the new modularized file structure
UseNewFileStructure = "USE_NEW_FILE_STRUCTURE"
// UseStandardErrors enables the standardized error package
UseStandardErrors = "USE_STANDARD_ERRORS"
// UseEnhancedLogging enables the enhanced logging system
UseEnhancedLogging = "USE_ENHANCED_LOGGING"
// UseOptimizedTests enables the consolidated test suite
UseOptimizedTests = "USE_OPTIMIZED_TESTS"
// UseRedisRESP enables the custom Redis RESP implementation
UseRedisRESP = "USE_REDIS_RESP"
)
// GetManager returns the global feature manager instance
func GetManager() *FeatureManager {
managerOnce.Do(func() {
manager = &FeatureManager{
flags: make(map[string]*FeatureFlag),
}
manager.initialize()
})
return manager
}
// initialize sets up default feature flags
func (m *FeatureManager) initialize() {
// Phase 0: Feature flags setup
m.Register(UseUnifiedConfig, "Enable unified configuration package", false)
m.Register(UseNewFileStructure, "Enable modularized file structure", false)
m.Register(UseStandardErrors, "Enable standardized error handling", false)
m.Register(UseEnhancedLogging, "Enable enhanced logging system", false)
m.Register(UseOptimizedTests, "Enable optimized test suite", false)
m.Register(UseRedisRESP, "Enable custom Redis RESP implementation", false)
// Load from environment variables
m.LoadFromEnv()
}
// Register creates a new feature flag
func (m *FeatureManager) Register(name, description string, defaultValue bool) {
m.mu.Lock()
defer m.mu.Unlock()
flag := &FeatureFlag{
name: name,
description: description,
callbacks: make([]func(bool), 0),
}
flag.enabled.Store(defaultValue)
m.flags[name] = flag
}
// IsEnabled checks if a feature flag is enabled
func (m *FeatureManager) IsEnabled(name string) bool {
m.mu.RLock()
flag, exists := m.flags[name]
m.mu.RUnlock()
if !exists {
return false
}
return flag.enabled.Load()
}
// Enable turns on a feature flag
func (m *FeatureManager) Enable(name string) {
m.setFlag(name, true)
}
// Disable turns off a feature flag
func (m *FeatureManager) Disable(name string) {
m.setFlag(name, false)
}
// Toggle switches a feature flag state
func (m *FeatureManager) Toggle(name string) {
m.mu.RLock()
flag, exists := m.flags[name]
m.mu.RUnlock()
if exists {
newValue := !flag.enabled.Load()
m.setFlag(name, newValue)
}
}
// setFlag updates a feature flag value and triggers callbacks
func (m *FeatureManager) setFlag(name string, value bool) {
m.mu.RLock()
flag, exists := m.flags[name]
m.mu.RUnlock()
if !exists {
return
}
oldValue := flag.enabled.Swap(value)
// Only trigger callbacks if value actually changed
if oldValue != value {
flag.mu.RLock()
callbacks := flag.callbacks
flag.mu.RUnlock()
for _, callback := range callbacks {
callback(value)
}
}
}
// OnChange registers a callback to be called when a feature flag changes
func (m *FeatureManager) OnChange(name string, callback func(bool)) {
m.mu.RLock()
flag, exists := m.flags[name]
m.mu.RUnlock()
if exists {
flag.mu.Lock()
flag.callbacks = append(flag.callbacks, callback)
flag.mu.Unlock()
}
}
// LoadFromEnv loads feature flag values from environment variables
func (m *FeatureManager) LoadFromEnv() {
m.mu.RLock()
flags := make(map[string]*FeatureFlag)
for name, flag := range m.flags {
flags[name] = flag
}
m.mu.RUnlock()
for name, flag := range flags {
envVar := "FEATURE_" + name
if value := os.Getenv(envVar); value != "" {
enabled := strings.ToLower(value) == "true" || value == "1"
flag.enabled.Store(enabled)
}
}
}
// GetAll returns all feature flags and their states
func (m *FeatureManager) GetAll() map[string]bool {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]bool)
for name, flag := range m.flags {
result[name] = flag.enabled.Load()
}
return result
}
// Reset resets all feature flags to their default values
func (m *FeatureManager) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
for _, flag := range m.flags {
flag.enabled.Store(false)
flag.callbacks = make([]func(bool), 0)
}
}
// Helper functions for common checks
// IsUnifiedConfigEnabled checks if unified config is enabled
func IsUnifiedConfigEnabled() bool {
return GetManager().IsEnabled(UseUnifiedConfig)
}
// IsNewFileStructureEnabled checks if new file structure is enabled
func IsNewFileStructureEnabled() bool {
return GetManager().IsEnabled(UseNewFileStructure)
}
// IsStandardErrorsEnabled checks if standard errors are enabled
func IsStandardErrorsEnabled() bool {
return GetManager().IsEnabled(UseStandardErrors)
}
// IsEnhancedLoggingEnabled checks if enhanced logging is enabled
func IsEnhancedLoggingEnabled() bool {
return GetManager().IsEnabled(UseEnhancedLogging)
}
// IsOptimizedTestsEnabled checks if optimized tests are enabled
func IsOptimizedTestsEnabled() bool {
return GetManager().IsEnabled(UseOptimizedTests)
}
// IsRedisRESPEnabled checks if custom Redis RESP is enabled
func IsRedisRESPEnabled() bool {
return GetManager().IsEnabled(UseRedisRESP)
}
+307
View File
@@ -0,0 +1,307 @@
// Package recovery provides error recovery and resilience mechanisms for OIDC authentication.
package recovery
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
// ErrorRecoveryMechanism defines the interface for error recovery strategies.
// It provides a common contract for implementing various resilience patterns
// such as circuit breakers, retry mechanisms, and fallback strategies.
type ErrorRecoveryMechanism interface {
// ExecuteWithContext runs a function with error recovery using the provided context
ExecuteWithContext(ctx context.Context, fn func() error) error
// Reset resets the recovery mechanism state
Reset()
// IsAvailable checks if the mechanism is currently available for use
IsAvailable() bool
// GetMetrics returns metrics about the recovery mechanism's performance
GetMetrics() map[string]interface{}
}
// Logger defines the logging interface
type Logger interface {
Logf(format string, args ...interface{})
ErrorLogf(format string, args ...interface{})
DebugLogf(format string, args ...interface{})
}
// BaseRecoveryMechanism provides common functionality and metrics tracking
// for all recovery mechanism implementations. It handles request counting,
// success/failure tracking, and timestamp management in a thread-safe manner.
type BaseRecoveryMechanism struct {
// name identifies the recovery mechanism instance
name string
// logger provides structured logging capabilities
logger Logger
// Metrics tracked with atomic operations for thread safety
totalRequests int64
successCount int64
failureCount int64
lastSuccessStr string
lastFailureStr string
// mutexes for thread-safe timestamp updates
successMutex sync.RWMutex
failureMutex sync.RWMutex
}
// NewBaseRecoveryMechanism creates a new base recovery mechanism with the given name and logger.
// This serves as the foundation for specific recovery mechanism implementations.
// Parameters:
// - name: Identifier for this recovery mechanism instance
// - logger: Logger instance for outputting diagnostic information
//
// Returns:
// - A new BaseRecoveryMechanism instance with initialized metrics
func NewBaseRecoveryMechanism(name string, logger Logger) *BaseRecoveryMechanism {
return &BaseRecoveryMechanism{
name: name,
logger: logger,
totalRequests: 0,
successCount: 0,
failureCount: 0,
lastSuccessStr: "never",
lastFailureStr: "never",
}
}
// RecordRequest increments the total request counter.
// This method is thread-safe using atomic operations.
func (b *BaseRecoveryMechanism) RecordRequest() {
atomic.AddInt64(&b.totalRequests, 1)
}
// RecordSuccess increments the success counter and updates the last success timestamp.
// This method is thread-safe using atomic operations for counters
// and mutex protection for timestamp updates.
func (b *BaseRecoveryMechanism) RecordSuccess() {
atomic.AddInt64(&b.successCount, 1)
b.successMutex.Lock()
b.lastSuccessStr = time.Now().Format(time.RFC3339)
b.successMutex.Unlock()
}
// RecordFailure increments the failure counter and updates the last failure timestamp.
// This method is thread-safe using atomic operations for counters
// and mutex protection for timestamp updates.
func (b *BaseRecoveryMechanism) RecordFailure() {
atomic.AddInt64(&b.failureCount, 1)
b.failureMutex.Lock()
b.lastFailureStr = time.Now().Format(time.RFC3339)
b.failureMutex.Unlock()
}
// GetBaseMetrics returns comprehensive metrics about the recovery mechanism.
// Includes request counts, success/failure rates, timing information,
// and calculated percentages. All access is thread-safe.
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
total := atomic.LoadInt64(&b.totalRequests)
success := atomic.LoadInt64(&b.successCount)
failure := atomic.LoadInt64(&b.failureCount)
b.successMutex.RLock()
lastSuccess := b.lastSuccessStr
b.successMutex.RUnlock()
b.failureMutex.RLock()
lastFailure := b.lastFailureStr
b.failureMutex.RUnlock()
metrics := map[string]interface{}{
"name": b.name,
"totalRequests": total,
"successCount": success,
"failureCount": failure,
"lastSuccess": lastSuccess,
"lastFailure": lastFailure,
}
// Calculate success and failure rates
if total > 0 {
successRate := float64(success) / float64(total) * 100
failureRate := float64(failure) / float64(total) * 100
metrics["successRate"] = fmt.Sprintf("%.2f%%", successRate)
metrics["failureRate"] = fmt.Sprintf("%.2f%%", failureRate)
} else {
metrics["successRate"] = "0.00%"
metrics["failureRate"] = "0.00%"
}
return metrics
}
// LogInfo logs an informational message with the mechanism name as prefix.
// Provides consistent logging format across all recovery mechanisms.
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
if b.logger != nil {
b.logger.Logf("[%s] %s", b.name, fmt.Sprintf(format, args...))
}
}
// LogError logs an error message with the mechanism name as prefix.
// Used for reporting failures and error conditions in recovery mechanisms.
func (b *BaseRecoveryMechanism) LogError(format string, args ...interface{}) {
if b.logger != nil {
b.logger.ErrorLogf("[%s] %s", b.name, fmt.Sprintf(format, args...))
}
}
// LogDebug logs a debug message with the mechanism name as prefix.
// Useful for detailed troubleshooting of recovery mechanism behavior.
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
if b.logger != nil {
b.logger.DebugLogf("[%s] %s", b.name, fmt.Sprintf(format, args...))
}
}
// ErrorType represents different categories of errors
type ErrorType int
const (
// ErrorTypeUnknown represents an unknown error type
ErrorTypeUnknown ErrorType = iota
// ErrorTypeNetwork represents network-related errors
ErrorTypeNetwork
// ErrorTypeTimeout represents timeout errors
ErrorTypeTimeout
// ErrorTypeAuthentication represents authentication errors
ErrorTypeAuthentication
// ErrorTypeRateLimit represents rate limiting errors
ErrorTypeRateLimit
// ErrorTypeServerError represents server errors (5xx)
ErrorTypeServerError
// ErrorTypeClientError represents client errors (4xx)
ErrorTypeClientError
)
// HTTPError represents an HTTP error with status code and message
type HTTPError struct {
StatusCode int
Message string
Body []byte
Headers map[string]string
}
// Error implements the error interface
func (e *HTTPError) Error() string {
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Message)
}
// IsRetryable checks if the HTTP error is retryable
func (e *HTTPError) IsRetryable() bool {
// Retry on 5xx errors and specific 4xx errors
return e.StatusCode >= 500 || e.StatusCode == 429 || e.StatusCode == 408
}
// OIDCError represents an OIDC-specific error
type OIDCError struct {
Code string
Description string
URI string
State string
}
// Error implements the error interface
func (e *OIDCError) Error() string {
if e.Description != "" {
return fmt.Sprintf("OIDC error %s: %s", e.Code, e.Description)
}
return fmt.Sprintf("OIDC error: %s", e.Code)
}
// IsRetryable checks if the OIDC error is retryable
func (e *OIDCError) IsRetryable() bool {
// Some OIDC errors are retryable
switch e.Code {
case "temporarily_unavailable", "server_error":
return true
default:
return false
}
}
// FallbackMechanism provides a simple fallback recovery strategy
type FallbackMechanism struct {
*BaseRecoveryMechanism
fallbackFunc func() error
}
// NewFallbackMechanism creates a new fallback mechanism
func NewFallbackMechanism(name string, logger Logger, fallbackFunc func() error) *FallbackMechanism {
return &FallbackMechanism{
BaseRecoveryMechanism: NewBaseRecoveryMechanism(name, logger),
fallbackFunc: fallbackFunc,
}
}
// ExecuteWithContext executes the primary function and falls back on error
func (f *FallbackMechanism) ExecuteWithContext(ctx context.Context, fn func() error) error {
f.RecordRequest()
// Check context first
select {
case <-ctx.Done():
f.RecordFailure()
return ctx.Err()
default:
}
// Try primary function
if err := fn(); err != nil {
f.LogInfo("Primary function failed: %v, trying fallback", err)
// Try fallback
if f.fallbackFunc != nil {
if fallbackErr := f.fallbackFunc(); fallbackErr == nil {
f.RecordSuccess()
return nil
} else {
f.LogError("Fallback also failed: %v", fallbackErr)
f.RecordFailure()
return fmt.Errorf("both primary and fallback failed: primary=%v, fallback=%v", err, fallbackErr)
}
}
f.RecordFailure()
return err
}
f.RecordSuccess()
return nil
}
// Reset resets the fallback mechanism state
func (f *FallbackMechanism) Reset() {
// Reset metrics
atomic.StoreInt64(&f.totalRequests, 0)
atomic.StoreInt64(&f.successCount, 0)
atomic.StoreInt64(&f.failureCount, 0)
f.successMutex.Lock()
f.lastSuccessStr = "never"
f.successMutex.Unlock()
f.failureMutex.Lock()
f.lastFailureStr = "never"
f.failureMutex.Unlock()
}
// IsAvailable checks if the fallback mechanism is available
func (f *FallbackMechanism) IsAvailable() bool {
// Fallback is always available
return true
}
// GetMetrics returns metrics about the fallback mechanism
func (f *FallbackMechanism) GetMetrics() map[string]interface{} {
metrics := f.GetBaseMetrics()
metrics["type"] = "fallback"
metrics["hasFallback"] = f.fallbackFunc != nil
return metrics
}
+336
View File
@@ -0,0 +1,336 @@
// Package recovery provides error recovery and resilience mechanisms for OIDC authentication.
package recovery
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
)
// CircuitBreakerState represents the current state of the circuit breaker
type CircuitBreakerState int
const (
// CircuitBreakerClosed allows all requests to pass through
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen blocks all requests
CircuitBreakerOpen
// CircuitBreakerHalfOpen allows limited requests for testing
CircuitBreakerHalfOpen
)
// String returns the string representation of the circuit breaker state
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreakerConfig defines configuration for the circuit breaker
type CircuitBreakerConfig struct {
// FailureThreshold is the number of failures before opening the circuit
FailureThreshold int
// SuccessThreshold is the number of successes in half-open state before closing
SuccessThreshold int
// Timeout is the duration to wait before transitioning from open to half-open
Timeout time.Duration
// MaxRequests is the maximum number of requests allowed in half-open state
MaxRequests int
}
// DefaultCircuitBreakerConfig returns sensible default configuration
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
FailureThreshold: 5,
SuccessThreshold: 2,
Timeout: 30 * time.Second,
MaxRequests: 3,
}
}
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
// It prevents cascading failures by temporarily blocking requests to a failing service.
type CircuitBreaker struct {
*BaseRecoveryMechanism
config CircuitBreakerConfig
// State management
state int32 // atomic: CircuitBreakerState
lastStateChange time.Time
stateMutex sync.RWMutex
// Failure tracking
consecutiveFailures int32 // atomic
consecutiveSuccesses int32 // atomic
// Half-open state management
halfOpenRequests int32 // atomic
}
// NewCircuitBreaker creates a new circuit breaker with the given configuration
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger) *CircuitBreaker {
return &CircuitBreaker{
BaseRecoveryMechanism: NewBaseRecoveryMechanism("CircuitBreaker", logger),
config: config,
state: int32(CircuitBreakerClosed),
lastStateChange: time.Now(),
consecutiveFailures: 0,
consecutiveSuccesses: 0,
halfOpenRequests: 0,
}
}
// ExecuteWithContext executes a function with circuit breaker protection
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
cb.RecordRequest()
// Check if request is allowed
if !cb.allowRequest() {
cb.RecordFailure()
return fmt.Errorf("circuit breaker is open")
}
// Execute the function
err := fn()
if err != nil {
cb.recordFailure()
return err
}
cb.recordSuccess()
return nil
}
// Execute executes a function with circuit breaker protection (legacy method)
func (cb *CircuitBreaker) Execute(fn func() error) error {
return cb.ExecuteWithContext(context.Background(), fn)
}
// allowRequest determines if a request should be allowed based on the circuit state
func (cb *CircuitBreaker) allowRequest() bool {
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
switch state {
case CircuitBreakerClosed:
return true
case CircuitBreakerOpen:
// Check if timeout has elapsed
cb.stateMutex.RLock()
lastChange := cb.lastStateChange
cb.stateMutex.RUnlock()
if time.Since(lastChange) > cb.config.Timeout {
// Transition to half-open
cb.transitionToHalfOpen()
return cb.allowHalfOpenRequest()
}
return false
case CircuitBreakerHalfOpen:
return cb.allowHalfOpenRequest()
default:
return false
}
}
// allowHalfOpenRequest checks if a request is allowed in half-open state
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
if current <= int32(cb.config.MaxRequests) {
return true
}
atomic.AddInt32(&cb.halfOpenRequests, -1)
return false
}
// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
cb.RecordFailure()
failures := atomic.AddInt32(&cb.consecutiveFailures, 1)
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
cb.transitionToOpen()
} else if state == CircuitBreakerHalfOpen {
cb.transitionToOpen()
}
}
// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
cb.RecordSuccess()
successes := atomic.AddInt32(&cb.consecutiveSuccesses, 1)
atomic.StoreInt32(&cb.consecutiveFailures, 0)
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
cb.transitionToClosed()
}
}
// transitionToClosed transitions the circuit to closed state
func (cb *CircuitBreaker) transitionToClosed() {
if atomic.CompareAndSwapInt32(&cb.state, int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.stateMutex.Lock()
cb.lastStateChange = time.Now()
cb.stateMutex.Unlock()
atomic.StoreInt32(&cb.consecutiveFailures, 0)
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
atomic.StoreInt32(&cb.halfOpenRequests, 0)
cb.LogInfo("Circuit breaker closed")
}
}
// transitionToOpen transitions the circuit to open state
func (cb *CircuitBreaker) transitionToOpen() {
oldState := atomic.SwapInt32(&cb.state, int32(CircuitBreakerOpen))
if oldState != int32(CircuitBreakerOpen) {
cb.stateMutex.Lock()
cb.lastStateChange = time.Now()
cb.stateMutex.Unlock()
atomic.StoreInt32(&cb.consecutiveFailures, 0)
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
atomic.StoreInt32(&cb.halfOpenRequests, 0)
cb.LogError("Circuit breaker opened due to failures")
}
}
// transitionToHalfOpen transitions the circuit to half-open state
func (cb *CircuitBreaker) transitionToHalfOpen() {
if atomic.CompareAndSwapInt32(&cb.state, int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.stateMutex.Lock()
cb.lastStateChange = time.Now()
cb.stateMutex.Unlock()
atomic.StoreInt32(&cb.consecutiveFailures, 0)
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
atomic.StoreInt32(&cb.halfOpenRequests, 0)
cb.LogInfo("Circuit breaker half-open, testing recovery")
}
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(atomic.LoadInt32(&cb.state))
}
// Reset resets the circuit breaker to closed state
func (cb *CircuitBreaker) Reset() {
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
cb.stateMutex.Lock()
cb.lastStateChange = time.Now()
cb.stateMutex.Unlock()
atomic.StoreInt32(&cb.consecutiveFailures, 0)
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
atomic.StoreInt32(&cb.halfOpenRequests, 0)
// Reset base metrics
atomic.StoreInt64(&cb.totalRequests, 0)
atomic.StoreInt64(&cb.successCount, 0)
atomic.StoreInt64(&cb.failureCount, 0)
cb.LogInfo("Circuit breaker reset to closed state")
}
// IsAvailable returns true if the circuit breaker is not fully open
func (cb *CircuitBreaker) IsAvailable() bool {
state := cb.GetState()
return state != CircuitBreakerOpen || time.Since(cb.getLastStateChange()) > cb.config.Timeout
}
// getLastStateChange returns the last state change time safely
func (cb *CircuitBreaker) getLastStateChange() time.Time {
cb.stateMutex.RLock()
defer cb.stateMutex.RUnlock()
return cb.lastStateChange
}
// GetMetrics returns comprehensive metrics about the circuit breaker
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
metrics := cb.GetBaseMetrics()
state := cb.GetState()
metrics["state"] = state.String()
metrics["consecutiveFailures"] = atomic.LoadInt32(&cb.consecutiveFailures)
metrics["consecutiveSuccesses"] = atomic.LoadInt32(&cb.consecutiveSuccesses)
metrics["halfOpenRequests"] = atomic.LoadInt32(&cb.halfOpenRequests)
cb.stateMutex.RLock()
metrics["lastStateChange"] = cb.lastStateChange.Format(time.RFC3339)
metrics["timeSinceLastChange"] = time.Since(cb.lastStateChange).String()
cb.stateMutex.RUnlock()
// Configuration
metrics["config"] = map[string]interface{}{
"failureThreshold": cb.config.FailureThreshold,
"successThreshold": cb.config.SuccessThreshold,
"timeout": cb.config.Timeout.String(),
"maxRequests": cb.config.MaxRequests,
}
// Health indicator
switch state {
case CircuitBreakerClosed:
metrics["health"] = "healthy"
case CircuitBreakerHalfOpen:
metrics["health"] = "recovering"
case CircuitBreakerOpen:
if time.Since(cb.getLastStateChange()) > cb.config.Timeout {
metrics["health"] = "ready-to-recover"
} else {
metrics["health"] = "unhealthy"
}
}
return metrics
}
// ForceOpen forces the circuit breaker to open state
func (cb *CircuitBreaker) ForceOpen() {
atomic.StoreInt32(&cb.state, int32(CircuitBreakerOpen))
cb.stateMutex.Lock()
cb.lastStateChange = time.Now()
cb.stateMutex.Unlock()
cb.LogInfo("Circuit breaker forced open")
}
// ForceClosed forces the circuit breaker to closed state
func (cb *CircuitBreaker) ForceClosed() {
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
cb.stateMutex.Lock()
cb.lastStateChange = time.Now()
cb.stateMutex.Unlock()
atomic.StoreInt32(&cb.consecutiveFailures, 0)
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
atomic.StoreInt32(&cb.halfOpenRequests, 0)
cb.LogInfo("Circuit breaker forced closed")
}
+391
View File
@@ -0,0 +1,391 @@
// Package recovery provides error recovery and resilience mechanisms for OIDC authentication.
package recovery
import (
"context"
"fmt"
"math"
"math/rand"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
)
// RetryConfig defines configuration for the retry executor
type RetryConfig struct {
// MaxAttempts is the maximum number of retry attempts
MaxAttempts int
// InitialDelay is the initial delay between retries
InitialDelay time.Duration
// MaxDelay is the maximum delay between retries
MaxDelay time.Duration
// Multiplier is the backoff multiplier
Multiplier float64
// RandomizationFactor adds jitter to delays (0.0 to 1.0)
RandomizationFactor float64
// RetryableErrors defines which errors should trigger a retry
RetryableErrors []string
// RetryableStatusCodes defines which HTTP status codes should trigger a retry
RetryableStatusCodes []int
}
// DefaultRetryConfig returns sensible default retry configuration
func DefaultRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 3,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 30 * time.Second,
Multiplier: 2.0,
RandomizationFactor: 0.1,
RetryableErrors: []string{"connection refused", "timeout", "EOF"},
RetryableStatusCodes: []int{408, 429, 500, 502, 503, 504},
}
}
// RetryExecutor implements retry logic with exponential backoff
type RetryExecutor struct {
*BaseRecoveryMechanism
config RetryConfig
// Metrics
totalRetries int64
maxRetriesHit int64
lastRetryTime time.Time
retryTimeMutex sync.RWMutex
}
// NewRetryExecutor creates a new retry executor with the given configuration
func NewRetryExecutor(config RetryConfig, logger Logger) *RetryExecutor {
if config.MaxAttempts < 1 {
config.MaxAttempts = 1
}
if config.Multiplier < 1.0 {
config.Multiplier = 1.0
}
return &RetryExecutor{
BaseRecoveryMechanism: NewBaseRecoveryMechanism("RetryExecutor", logger),
config: config,
totalRetries: 0,
maxRetriesHit: 0,
}
}
// ExecuteWithContext executes a function with retry logic
func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error) error {
re.RecordRequest()
var lastErr error
for attempt := 1; attempt <= re.config.MaxAttempts; attempt++ {
// Check context before attempting
select {
case <-ctx.Done():
re.RecordFailure()
return ctx.Err()
default:
}
// Execute the function
lastErr = fn()
if lastErr == nil {
re.RecordSuccess()
if attempt > 1 {
re.LogInfo("Succeeded after %d attempts", attempt)
}
return nil
}
// Check if error is retryable
if !re.isRetryableError(lastErr) {
re.LogDebug("Error is not retryable: %v", lastErr)
re.RecordFailure()
return lastErr
}
// Don't retry if this was the last attempt
if attempt >= re.config.MaxAttempts {
atomic.AddInt64(&re.maxRetriesHit, 1)
re.LogError("Max retries (%d) exhausted", re.config.MaxAttempts)
break
}
// Calculate and apply delay
delay := re.calculateDelay(attempt)
re.LogInfo("Attempt %d failed: %v, retrying in %v", attempt, lastErr, delay)
atomic.AddInt64(&re.totalRetries, 1)
re.retryTimeMutex.Lock()
re.lastRetryTime = time.Now()
re.retryTimeMutex.Unlock()
select {
case <-time.After(delay):
// Continue to next attempt
case <-ctx.Done():
re.RecordFailure()
return fmt.Errorf("retry cancelled: %w", ctx.Err())
}
}
re.RecordFailure()
return fmt.Errorf("all retry attempts failed: %w", lastErr)
}
// Execute executes a function with retry logic (legacy method)
func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
return re.ExecuteWithContext(ctx, fn)
}
// isRetryableError determines if an error should trigger a retry
func (re *RetryExecutor) isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
// Check for retryable error patterns
for _, pattern := range re.config.RetryableErrors {
if strings.Contains(errStr, strings.ToLower(pattern)) {
return true
}
}
// Check for HTTP errors
if httpErr, ok := err.(*HTTPError); ok {
for _, code := range re.config.RetryableStatusCodes {
if httpErr.StatusCode == code {
return true
}
}
// Also retry on any 5xx error
if httpErr.StatusCode >= 500 && httpErr.StatusCode < 600 {
return true
}
}
// Check for OIDC errors
if oidcErr, ok := err.(*OIDCError); ok {
return oidcErr.IsRetryable()
}
// Check for context errors (don't retry these)
if err == context.Canceled || err == context.DeadlineExceeded {
return false
}
// Default: don't retry unknown errors
return false
}
// calculateDelay calculates the delay before the next retry attempt
func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
// Exponential backoff
delay := float64(re.config.InitialDelay) * math.Pow(re.config.Multiplier, float64(attempt-1))
// Cap at max delay
if delay > float64(re.config.MaxDelay) {
delay = float64(re.config.MaxDelay)
}
// Add jitter
if re.config.RandomizationFactor > 0 {
jitter := delay * re.config.RandomizationFactor
minDelay := delay - jitter
maxDelay := delay + jitter
delay = minDelay + rand.Float64()*(maxDelay-minDelay)
}
return time.Duration(delay)
}
// Reset resets the retry executor state
func (re *RetryExecutor) Reset() {
atomic.StoreInt64(&re.totalRetries, 0)
atomic.StoreInt64(&re.maxRetriesHit, 0)
atomic.StoreInt64(&re.totalRequests, 0)
atomic.StoreInt64(&re.successCount, 0)
atomic.StoreInt64(&re.failureCount, 0)
re.retryTimeMutex.Lock()
re.lastRetryTime = time.Time{}
re.retryTimeMutex.Unlock()
}
// IsAvailable always returns true for retry executor
func (re *RetryExecutor) IsAvailable() bool {
return true
}
// GetMetrics returns comprehensive metrics about the retry executor
func (re *RetryExecutor) GetMetrics() map[string]interface{} {
metrics := re.GetBaseMetrics()
metrics["totalRetries"] = atomic.LoadInt64(&re.totalRetries)
metrics["maxRetriesHit"] = atomic.LoadInt64(&re.maxRetriesHit)
re.retryTimeMutex.RLock()
if !re.lastRetryTime.IsZero() {
metrics["lastRetryTime"] = re.lastRetryTime.Format(time.RFC3339)
metrics["timeSinceLastRetry"] = time.Since(re.lastRetryTime).String()
} else {
metrics["lastRetryTime"] = "never"
}
re.retryTimeMutex.RUnlock()
// Configuration
metrics["config"] = map[string]interface{}{
"maxAttempts": re.config.MaxAttempts,
"initialDelay": re.config.InitialDelay.String(),
"maxDelay": re.config.MaxDelay.String(),
"multiplier": re.config.Multiplier,
"randomizationFactor": re.config.RandomizationFactor,
}
// Calculate average retries per request
totalRequests := atomic.LoadInt64(&re.totalRequests)
if totalRequests > 0 {
avgRetries := float64(atomic.LoadInt64(&re.totalRetries)) / float64(totalRequests)
metrics["averageRetriesPerRequest"] = fmt.Sprintf("%.2f", avgRetries)
}
return metrics
}
// RecoveryMetrics aggregates metrics from multiple recovery mechanisms
type RecoveryMetrics struct {
mechanisms map[string]ErrorRecoveryMechanism
mu sync.RWMutex
}
// NewRecoveryMetrics creates a new recovery metrics aggregator
func NewRecoveryMetrics() *RecoveryMetrics {
return &RecoveryMetrics{
mechanisms: make(map[string]ErrorRecoveryMechanism),
}
}
// RegisterMechanism registers a recovery mechanism for metrics collection
func (rm *RecoveryMetrics) RegisterMechanism(name string, mechanism ErrorRecoveryMechanism) {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.mechanisms[name] = mechanism
}
// UnregisterMechanism removes a recovery mechanism from metrics collection
func (rm *RecoveryMetrics) UnregisterMechanism(name string) {
rm.mu.Lock()
defer rm.mu.Unlock()
delete(rm.mechanisms, name)
}
// GetAllMetrics returns aggregated metrics from all registered mechanisms
func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
rm.mu.RLock()
defer rm.mu.RUnlock()
allMetrics := make(map[string]interface{})
for name, mechanism := range rm.mechanisms {
allMetrics[name] = mechanism.GetMetrics()
}
// Add summary statistics
totalRequests := int64(0)
totalSuccesses := int64(0)
totalFailures := int64(0)
for _, mechanism := range rm.mechanisms {
metrics := mechanism.GetMetrics()
if requests, ok := metrics["totalRequests"].(int64); ok {
totalRequests += requests
}
if successes, ok := metrics["successCount"].(int64); ok {
totalSuccesses += successes
}
if failures, ok := metrics["failureCount"].(int64); ok {
totalFailures += failures
}
}
allMetrics["summary"] = map[string]interface{}{
"totalMechanisms": len(rm.mechanisms),
"totalRequests": totalRequests,
"totalSuccesses": totalSuccesses,
"totalFailures": totalFailures,
}
if totalRequests > 0 {
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
}
return allMetrics
}
// GetMechanismMetrics returns metrics for a specific mechanism
func (rm *RecoveryMetrics) GetMechanismMetrics(name string) (map[string]interface{}, bool) {
rm.mu.RLock()
defer rm.mu.RUnlock()
if mechanism, exists := rm.mechanisms[name]; exists {
return mechanism.GetMetrics(), true
}
return nil, false
}
// HealthCheck performs a health check on all registered mechanisms
func (rm *RecoveryMetrics) HealthCheck() map[string]interface{} {
rm.mu.RLock()
defer rm.mu.RUnlock()
health := make(map[string]interface{})
healthyCount := 0
unhealthyCount := 0
for name, mechanism := range rm.mechanisms {
if mechanism.IsAvailable() {
health[name] = "healthy"
healthyCount++
} else {
health[name] = "unhealthy"
unhealthyCount++
}
}
overallHealth := "healthy"
if unhealthyCount > 0 {
if healthyCount > 0 {
overallHealth = "degraded"
} else {
overallHealth = "unhealthy"
}
}
return map[string]interface{}{
"status": overallHealth,
"mechanisms": health,
"healthy": healthyCount,
"unhealthy": unhealthyCount,
"timestamp": time.Now().Format(time.RFC3339),
}
}
// HTTPMetricsHandler creates an HTTP handler for serving recovery metrics
func (rm *RecoveryMetrics) HTTPMetricsHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
metrics := rm.GetAllMetrics()
health := rm.HealthCheck()
response := map[string]interface{}{
"metrics": metrics,
"health": health,
}
// Would normally use json.Marshal here, but keeping it simple for the module
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "%v", response)
}
}
+312
View File
@@ -0,0 +1,312 @@
// Package token provides token management functionality for OIDC authentication.
package token
import (
"fmt"
"net/http"
"sync"
"time"
)
// TokenCache manages cached verified tokens
type TokenCache struct {
cache CacheInterface
blacklist CacheInterface
logger LoggerInterface
metrics MetricsInterface
cleanupTicker *time.Ticker
cleanupStop chan bool
mu sync.RWMutex
maxTTL time.Duration
}
// NewTokenCache creates a new token cache manager
func NewTokenCache(cache, blacklist CacheInterface, logger LoggerInterface, metrics MetricsInterface, maxTTL time.Duration) *TokenCache {
return &TokenCache{
cache: cache,
blacklist: blacklist,
logger: logger,
metrics: metrics,
maxTTL: maxTTL,
cleanupStop: make(chan bool),
}
}
// CacheToken stores a verified token with its claims in cache
func (tc *TokenCache) CacheToken(token string, claims map[string]interface{}) {
if token == "" || len(claims) == 0 {
return
}
tc.mu.Lock()
defer tc.mu.Unlock()
// Add timestamp for TTL management
claimsWithMeta := make(map[string]interface{})
for k, v := range claims {
claimsWithMeta[k] = v
}
claimsWithMeta["_cached_at"] = time.Now().Unix()
tc.cache.Set(token, claimsWithMeta)
tc.logger.Logf("Cached verified token (claims count: %d)", len(claims))
}
// GetCachedToken retrieves a token's claims from cache if present and valid
func (tc *TokenCache) GetCachedToken(token string) (map[string]interface{}, bool) {
if token == "" {
return nil, false
}
tc.mu.RLock()
defer tc.mu.RUnlock()
claims, exists := tc.cache.Get(token)
if !exists || len(claims) == 0 {
return nil, false
}
// Check if token is blacklisted
if tc.isBlacklisted(token, claims) {
tc.cache.Delete(token)
return nil, false
}
// Check cache TTL
if cachedAt, ok := claims["_cached_at"].(int64); ok {
if time.Since(time.Unix(cachedAt, 0)) > tc.maxTTL {
tc.cache.Delete(token)
return nil, false
}
}
// Check token expiry from claims
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
tc.cache.Delete(token)
return nil, false
}
}
tc.logger.Logf("Token found in cache (valid)")
return claims, true
}
// InvalidateToken removes a token from cache and adds it to blacklist
func (tc *TokenCache) InvalidateToken(token string) {
if token == "" {
return
}
tc.mu.Lock()
defer tc.mu.Unlock()
// Remove from cache
tc.cache.Delete(token)
// Add to blacklist
if tc.blacklist != nil {
tc.blacklist.Set(token, map[string]interface{}{
"invalidated_at": time.Now().Unix(),
"reason": "manual_invalidation",
})
// Also blacklist JTI if present
if claims, exists := tc.cache.Get(token); exists {
if jti, ok := claims["jti"].(string); ok && jti != "" {
tc.blacklist.Set(jti, map[string]interface{}{
"invalidated_at": time.Now().Unix(),
"reason": "jti_invalidation",
})
}
}
}
tc.logger.Logf("Token invalidated and blacklisted")
}
// StartCleanup starts the background cleanup process for expired tokens
func (tc *TokenCache) StartCleanup(interval time.Duration) {
tc.mu.Lock()
defer tc.mu.Unlock()
if tc.cleanupTicker != nil {
return // Already running
}
tc.cleanupTicker = time.NewTicker(interval)
go func() {
for {
select {
case <-tc.cleanupTicker.C:
tc.cleanupExpiredTokens()
case <-tc.cleanupStop:
return
}
}
}()
tc.logger.Logf("Started token cache cleanup (interval: %v)", interval)
}
// StopCleanup stops the background cleanup process
func (tc *TokenCache) StopCleanup() {
tc.mu.Lock()
defer tc.mu.Unlock()
if tc.cleanupTicker != nil {
tc.cleanupTicker.Stop()
tc.cleanupTicker = nil
close(tc.cleanupStop)
tc.cleanupStop = make(chan bool)
tc.logger.Logf("Stopped token cache cleanup")
}
}
// cleanupExpiredTokens removes expired tokens from cache
func (tc *TokenCache) cleanupExpiredTokens() {
tc.mu.Lock()
defer tc.mu.Unlock()
// This would need to iterate through cache entries
// Since we're using an interface, we'd need to add a method to get all keys
// For now, this is a placeholder that would be implemented based on the actual cache implementation
tc.logger.Logf("Running token cache cleanup")
}
// isBlacklisted checks if a token or its JTI is blacklisted
func (tc *TokenCache) isBlacklisted(token string, claims map[string]interface{}) bool {
if tc.blacklist == nil {
return false
}
// Check token itself
if blacklisted, exists := tc.blacklist.Get(token); exists && blacklisted != nil {
return true
}
// Check JTI
if jti, ok := claims["jti"].(string); ok && jti != "" {
if blacklisted, exists := tc.blacklist.Get(jti); exists && blacklisted != nil {
return true
}
}
return false
}
// TokenBlacklist manages blacklisted tokens
type TokenBlacklist struct {
blacklist CacheInterface
logger LoggerInterface
mu sync.RWMutex
}
// NewTokenBlacklist creates a new token blacklist manager
func NewTokenBlacklist(blacklist CacheInterface, logger LoggerInterface) *TokenBlacklist {
return &TokenBlacklist{
blacklist: blacklist,
logger: logger,
}
}
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(token string, reason string) {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.blacklist.Set(token, map[string]interface{}{
"blacklisted_at": time.Now().Unix(),
"reason": reason,
})
tb.logger.Logf("Token added to blacklist (reason: %s)", reason)
}
// AddJTI adds a JTI to the blacklist for replay detection
func (tb *TokenBlacklist) AddJTI(jti string) {
tb.mu.Lock()
defer tb.mu.Unlock()
tb.blacklist.Set(jti, map[string]interface{}{
"blacklisted_at": time.Now().Unix(),
"reason": "jti_replay_detection",
})
tb.logger.Logf("JTI added to blacklist for replay detection")
}
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(token string) bool {
tb.mu.RLock()
defer tb.mu.RUnlock()
if blacklisted, exists := tb.blacklist.Get(token); exists && blacklisted != nil {
return true
}
return false
}
// IsJTIBlacklisted checks if a JTI is blacklisted
func (tb *TokenBlacklist) IsJTIBlacklisted(jti string) bool {
tb.mu.RLock()
defer tb.mu.RUnlock()
if blacklisted, exists := tb.blacklist.Get(jti); exists && blacklisted != nil {
return true
}
return false
}
// TokenRevocationManager handles token revocation with providers
type TokenRevocationManager struct {
clientID string
clientSecret string
revocationURL string
httpClient *http.Client
logger LoggerInterface
blacklist *TokenBlacklist
}
// NewTokenRevocationManager creates a new revocation manager
func NewTokenRevocationManager(clientID, clientSecret, revocationURL string, httpClient *http.Client, logger LoggerInterface, blacklist *TokenBlacklist) *TokenRevocationManager {
return &TokenRevocationManager{
clientID: clientID,
clientSecret: clientSecret,
revocationURL: revocationURL,
httpClient: httpClient,
logger: logger,
blacklist: blacklist,
}
}
// RevokeToken revokes a token locally and optionally with the provider
func (trm *TokenRevocationManager) RevokeToken(token string, tokenType string, withProvider bool) error {
// Add to local blacklist immediately
trm.blacklist.Add(token, fmt.Sprintf("revoked_%s", tokenType))
// Parse token to get JTI
if jwt, err := parseJWT(token); err == nil {
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
trm.blacklist.AddJTI(jti)
}
}
// Revoke with provider if requested
if withProvider && trm.revocationURL != "" {
return trm.revokeWithProvider(token, tokenType)
}
return nil
}
// revokeWithProvider sends revocation request to the OIDC provider
func (trm *TokenRevocationManager) revokeWithProvider(token, tokenType string) error {
// Implementation would send HTTP request to revocation endpoint
// This is simplified for module structure
trm.logger.Logf("Revoking %s with provider", tokenType)
return nil
}
+265
View File
@@ -0,0 +1,265 @@
// Package token provides token management functionality for OIDC authentication.
package token
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
// Introspector handles token introspection operations
type Introspector struct {
clientID string
clientSecret string
introspectionURL string
httpClient *http.Client
logger LoggerInterface
groupsClaimPath []string
rolesClaimPath []string
extractClaimsRegex string
}
// NewIntrospector creates a new token introspector
func NewIntrospector(clientID, clientSecret, introspectionURL string, httpClient *http.Client, logger LoggerInterface, groupsClaimPath, rolesClaimPath []string, extractClaimsRegex string) *Introspector {
return &Introspector{
clientID: clientID,
clientSecret: clientSecret,
introspectionURL: introspectionURL,
httpClient: httpClient,
logger: logger,
groupsClaimPath: groupsClaimPath,
rolesClaimPath: rolesClaimPath,
extractClaimsRegex: extractClaimsRegex,
}
}
// IntrospectToken performs token introspection with the OIDC provider
func (i *Introspector) IntrospectToken(token string, tokenTypeHint string) (*IntrospectionResponse, error) {
if i.introspectionURL == "" {
return nil, fmt.Errorf("introspection endpoint not configured")
}
data := url.Values{}
data.Set("token", token)
if tokenTypeHint != "" {
data.Set("token_type_hint", tokenTypeHint)
}
data.Set("client_id", i.clientID)
data.Set("client_secret", i.clientSecret)
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, i.introspectionURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create introspection request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := i.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("introspection request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read introspection response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("introspection failed with status %d: %s", resp.StatusCode, string(body))
}
var introspectResp IntrospectionResponse
if err := json.Unmarshal(body, &introspectResp); err != nil {
return nil, fmt.Errorf("failed to parse introspection response: %w", err)
}
// Parse any extra fields
var raw map[string]interface{}
if err := json.Unmarshal(body, &raw); err == nil {
introspectResp.Extra = make(map[string]interface{})
for k, v := range raw {
switch k {
case "active", "scope", "client_id", "username", "token_type",
"exp", "iat", "nbf", "sub", "aud", "iss", "jti":
// Skip standard fields
default:
introspectResp.Extra[k] = v
}
}
}
return &introspectResp, nil
}
// ExtractGroupsAndRoles extracts groups and roles from an ID token
func (i *Introspector) ExtractGroupsAndRoles(idToken string) ([]string, []string, error) {
jwt, err := parseJWT(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse ID token: %w", err)
}
groups := i.extractClaimValues(jwt.Claims, i.groupsClaimPath)
roles := i.extractClaimValues(jwt.Claims, i.rolesClaimPath)
i.logger.Logf("Extracted %d groups and %d roles from ID token", len(groups), len(roles))
return groups, roles, nil
}
// DetectTokenType analyzes a token and determines its type
func (i *Introspector) DetectTokenType(token string) (string, error) {
jwt, err := parseJWT(token)
if err != nil {
return "", fmt.Errorf("failed to parse token: %w", err)
}
// Check for ID token characteristics
if aud, ok := jwt.Claims["aud"]; ok {
switch v := aud.(type) {
case string:
if v == i.clientID {
return "id_token", nil
}
case []interface{}:
for _, a := range v {
if str, ok := a.(string); ok && str == i.clientID {
return "id_token", nil
}
}
}
}
// Check for access token characteristics
if scope, ok := jwt.Claims["scope"]; ok {
if _, isString := scope.(string); isString {
return "access_token", nil
}
}
// Check token_use claim (AWS Cognito specific)
if tokenUse, ok := jwt.Claims["token_use"]; ok {
if use, isString := tokenUse.(string); isString {
switch use {
case "id":
return "id_token", nil
case "access":
return "access_token", nil
}
}
}
// Check typ header
if typ, ok := jwt.Header["typ"]; ok {
if typStr, isString := typ.(string); isString {
switch strings.ToLower(typStr) {
case "jwt", "at+jwt":
return "access_token", nil
case "id+jwt":
return "id_token", nil
}
}
}
return "unknown", nil
}
// extractClaimValues extracts claim values from JWT claims using a path
func (i *Introspector) extractClaimValues(claims map[string]interface{}, claimPath []string) []string {
if len(claimPath) == 0 {
return nil
}
var result []string
current := claims
for idx, key := range claimPath {
if idx == len(claimPath)-1 {
// Last key - extract the values
if val, exists := current[key]; exists {
result = i.extractStringSlice(val)
}
} else {
// Navigate deeper
if next, ok := current[key].(map[string]interface{}); ok {
current = next
} else {
break
}
}
}
return result
}
// extractStringSlice converts various types to string slice
func (i *Introspector) extractStringSlice(val interface{}) []string {
switch v := val.(type) {
case []interface{}:
var result []string
for _, item := range v {
if str, ok := item.(string); ok {
result = append(result, str)
}
}
return result
case []string:
return v
case string:
if v != "" {
// Handle comma-separated or space-separated values
if strings.Contains(v, ",") {
return strings.Split(v, ",")
}
return []string{v}
}
}
return nil
}
// parseJWT parses a JWT token without verification
func parseJWT(token string) (*JWT, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
header, err := decodeSegment(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode header: %w", err)
}
claims, err := decodeSegment(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode claims: %w", err)
}
return &JWT{
Header: header,
Claims: claims,
}, nil
}
// decodeSegment decodes a base64url encoded JWT segment
func decodeSegment(seg string) (map[string]interface{}, error) {
// Add padding if necessary
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
decoded, err := base64.URLEncoding.DecodeString(seg)
if err != nil {
return nil, fmt.Errorf("failed to decode segment: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(decoded, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal segment: %w", err)
}
return result, nil
}
+182
View File
@@ -0,0 +1,182 @@
// Package token provides token management functionality for OIDC authentication.
package token
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// Refresher handles token refresh operations
type Refresher struct {
clientID string
clientSecret string
tokenURL string
httpClient *http.Client
logger LoggerInterface
metrics MetricsInterface
sessionManager SessionManagerInterface
tokenCache CacheInterface
verifier TokenVerifier
}
// NewRefresher creates a new token refresher
func NewRefresher(clientID, clientSecret, tokenURL string, httpClient *http.Client, logger LoggerInterface, metrics MetricsInterface, sessionManager SessionManagerInterface, tokenCache CacheInterface, verifier TokenVerifier) *Refresher {
return &Refresher{
clientID: clientID,
clientSecret: clientSecret,
tokenURL: tokenURL,
httpClient: httpClient,
logger: logger,
metrics: metrics,
sessionManager: sessionManager,
tokenCache: tokenCache,
verifier: verifier,
}
}
// RefreshToken attempts to refresh expired tokens using the refresh token.
// Returns true if refresh was successful or not needed, false if refresh failed and session should be terminated.
func (r *Refresher) RefreshToken(rw http.ResponseWriter, req *http.Request, session SessionDataInterface) bool {
if session == nil {
r.logger.ErrorLogf("RefreshToken: Session is nil")
return false
}
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
r.logger.Logf("No refresh token available, cannot refresh")
return false
}
r.logger.Logf("Attempting to refresh expired tokens")
tokenResp, err := r.GetNewTokenWithRefreshToken(refreshToken)
if err != nil {
r.logger.ErrorLogf("Failed to refresh tokens: %v", err)
r.metrics.RecordTokenRefreshError()
return false
}
// Parse expiry from expires_in
var idTokenExpiry, accessTokenExpiry time.Time
if tokenResp.ExpiresIn > 0 {
expiry := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
idTokenExpiry = expiry
accessTokenExpiry = expiry
}
// Update session with new tokens
if tokenResp.IDToken != "" && tokenResp.AccessToken != "" {
session.SetTokens(
tokenResp.IDToken,
tokenResp.AccessToken,
tokenResp.RefreshToken,
idTokenExpiry,
accessTokenExpiry,
)
} else if tokenResp.IDToken != "" {
session.SetIDToken(tokenResp.IDToken, idTokenExpiry)
if tokenResp.RefreshToken != "" {
session.SetRefreshToken(tokenResp.RefreshToken)
}
} else if tokenResp.AccessToken != "" {
session.SetAccessToken(tokenResp.AccessToken, accessTokenExpiry)
if tokenResp.RefreshToken != "" {
session.SetRefreshToken(tokenResp.RefreshToken)
}
}
// Clear old tokens from cache
if oldIDToken := session.GetIDToken(); oldIDToken != "" {
r.tokenCache.Delete(oldIDToken)
}
if oldAccessToken := session.GetAccessToken(); oldAccessToken != "" {
r.tokenCache.Delete(oldAccessToken)
}
// Verify and cache new tokens
if tokenResp.IDToken != "" {
if err := r.verifier.VerifyToken(tokenResp.IDToken); err != nil {
r.logger.ErrorLogf("Failed to verify refreshed ID token: %v", err)
return false
}
}
if tokenResp.AccessToken != "" {
if err := r.verifier.VerifyToken(tokenResp.AccessToken); err != nil {
r.logger.ErrorLogf("Failed to verify refreshed access token: %v", err)
return false
}
}
// Save updated session
if err := session.SaveToCache(); err != nil {
r.logger.ErrorLogf("Failed to save refreshed session: %v", err)
return false
}
r.metrics.RecordTokenRefresh()
r.logger.Logf("Successfully refreshed tokens")
return true
}
// GetNewTokenWithRefreshToken exchanges a refresh token for new tokens
func (r *Refresher) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
return r.exchangeToken("refresh_token", refreshToken, "", "")
}
// exchangeToken performs the actual token exchange with the provider
func (r *Refresher) exchangeToken(grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
data := url.Values{}
data.Set("client_id", r.clientID)
data.Set("client_secret", r.clientSecret)
data.Set("grant_type", grantType)
switch grantType {
case "authorization_code":
data.Set("code", codeOrToken)
if redirectURL != "" {
data.Set("redirect_uri", redirectURL)
}
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}
case "refresh_token":
data.Set("refresh_token", codeOrToken)
default:
return nil, fmt.Errorf("unsupported grant type: %s", grantType)
}
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, r.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token exchange request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read token response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
}
var tokenResp TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
return &tokenResp, nil
}
+184
View File
@@ -0,0 +1,184 @@
package token
import (
"net/http"
"time"
)
// TokenResponse represents the response from a token endpoint.
// It contains the tokens and additional metadata returned by the OIDC provider.
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
// JWT represents a parsed JSON Web Token.
// It contains the decoded header and claims from the token.
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
}
// JWK represents a JSON Web Key used for token verification.
// It contains the cryptographic key material and metadata.
type JWK struct {
Kty string `json:"kty"`
Use string `json:"use"`
Kid string `json:"kid"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
X5c []string `json:"x5c,omitempty"`
}
// JWKS represents a JSON Web Key Set.
// It contains multiple public keys that can be used for token verification.
type JWKS struct {
Keys []JWK `json:"keys"`
}
// TokenVerifier interface for verifying tokens
type TokenVerifier interface {
VerifyToken(token string) error
}
// TokenExchanger interface for exchanging tokens
type TokenExchanger interface {
GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error)
ExchangeCodeForToken(ctx interface{}, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
}
// ClaimsExtractor function type for extracting claims from tokens
type ClaimsExtractor func(token string) (map[string]interface{}, error)
// CacheInterface defines cache operations for storing token data
type CacheInterface interface {
Get(key string) (map[string]interface{}, bool)
Set(key string, value map[string]interface{})
Delete(key string)
}
// TokenCacheInterface defines methods for token caching operations
type TokenCacheInterface interface {
CacheToken(token string, claims map[string]interface{})
GetCachedToken(token string) (map[string]interface{}, bool)
InvalidateToken(token string)
StartCleanup(interval time.Duration)
StopCleanup()
}
// LoggerInterface defines logging methods
type LoggerInterface interface {
Logf(format string, args ...interface{})
ErrorLogf(format string, args ...interface{})
}
// MetricsInterface defines metrics tracking methods
type MetricsInterface interface {
RecordTokenRefresh()
RecordTokenRefreshError()
}
// SessionManagerInterface defines session management methods
type SessionManagerInterface interface {
GetSession(sessionID string) (SessionDataInterface, error)
SaveSession(session SessionDataInterface) error
}
// SessionDataInterface defines minimal session interface needed by refresher
type SessionDataInterface interface {
GetRefreshToken() string
GetIDToken() string
GetAccessToken() string
GetIDTokenExpiry() time.Time
GetAccessTokenExpiry() time.Time
SetIDToken(token string, expiry time.Time)
SetAccessToken(token string, expiry time.Time)
SetRefreshToken(token string)
SetTokens(idToken, accessToken, refreshToken string, idExpiry, accessExpiry time.Time)
SaveToCache() error
}
// IntrospectorInterface defines methods for token introspection
type IntrospectorInterface interface {
IntrospectToken(token string, tokenTypeHint string) (*IntrospectionResponse, error)
ExtractGroupsAndRoles(idToken string) ([]string, []string, error)
DetectTokenType(token string) (string, error)
}
// IntrospectionResponse represents the response from token introspection
type IntrospectionResponse struct {
Active bool `json:"active"`
Scope string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
Username string `json:"username,omitempty"`
TokenType string `json:"token_type,omitempty"`
Exp int64 `json:"exp,omitempty"`
Iat int64 `json:"iat,omitempty"`
Nbf int64 `json:"nbf,omitempty"`
Sub string `json:"sub,omitempty"`
Aud interface{} `json:"aud,omitempty"`
Iss string `json:"iss,omitempty"`
Jti string `json:"jti,omitempty"`
Extra map[string]interface{} `json:"-"`
}
// RefresherInterface defines methods for token refresh operations
type RefresherInterface interface {
RefreshToken(rw http.ResponseWriter, req *http.Request, session SessionDataInterface) bool
GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error)
}
// RevokeTokenEntry represents a token revocation request
type RevokeTokenEntry struct {
Token string
TokenType string
RevokedAt time.Time
Reason string
}
// ValidatorConfig contains configuration for the token validator
type ValidatorConfig struct {
ClientID string
Audience string
IssuerURL string
JwksURL string
TokenCache TokenCacheInterface
TokenBlacklist CacheInterface
TokenTypeCache CacheInterface
JwkCache interface{}
HTTPClient *http.Client
Limiter interface{}
ExtractClaimsFunc ClaimsExtractor
TokenVerifier TokenVerifier
DisableReplayDetection bool
SuppressDiagnosticLogs bool
MetadataMu interface{} // sync.RWMutex
Logger interface{}
}
// Constants for token validation
const (
DefaultBlacklistDuration = 24 * time.Hour
TokenCacheDuration = 5 * time.Minute
)
// Token type constants
const (
TokenTypeAccess = "ACCESS_TOKEN"
TokenTypeID = "ID_TOKEN"
TokenTypeRefresh = "REFRESH_TOKEN"
TokenTypeUnknown = "UNKNOWN"
)
// Provider constants
const (
ProviderGoogle = "google"
ProviderAzure = "azure"
ProviderOkta = "okta"
ProviderAuth0 = "auth0"
)
+355
View File
@@ -0,0 +1,355 @@
package token
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
// Validator handles token validation operations
type Validator struct {
clientID string
audience string
issuerURL string
jwksURL string
tokenCache TokenCacheInterface
tokenBlacklist CacheInterface
tokenTypeCache CacheInterface
jwkCache interface{} // JWK cache interface
httpClient *http.Client
limiter interface{} // Rate limiter interface
extractClaimsFunc ClaimsExtractor
tokenVerifier TokenVerifier
disableReplayDetection bool
suppressDiagnosticLogs bool
metadataMu *sync.RWMutex
logger interface{} // Logger interface
}
// NewValidator creates a new token validator
func NewValidator(config ValidatorConfig) *Validator {
var metadataMu *sync.RWMutex
if config.MetadataMu != nil {
if mu, ok := config.MetadataMu.(*sync.RWMutex); ok {
metadataMu = mu
}
}
return &Validator{
clientID: config.ClientID,
audience: config.Audience,
issuerURL: config.IssuerURL,
jwksURL: config.JwksURL,
tokenCache: config.TokenCache,
tokenBlacklist: config.TokenBlacklist,
tokenTypeCache: config.TokenTypeCache,
jwkCache: config.JwkCache,
httpClient: config.HTTPClient,
limiter: config.Limiter,
extractClaimsFunc: config.ExtractClaimsFunc,
tokenVerifier: config.TokenVerifier,
disableReplayDetection: config.DisableReplayDetection,
suppressDiagnosticLogs: config.SuppressDiagnosticLogs,
metadataMu: metadataMu,
logger: config.Logger,
}
}
// VerifyToken verifies the validity of an ID token or access token.
// It performs comprehensive validation including format checks, blacklist verification,
// signature validation using JWKs, and standard claims validation.
func (v *Validator) VerifyToken(token string) error {
if token == "" {
return fmt.Errorf("invalid JWT format: token is empty")
}
if strings.Count(token, ".") != 2 {
return fmt.Errorf("invalid JWT format: expected JWT with 3 parts, got %d parts", strings.Count(token, ".")+1)
}
if len(token) < 10 {
return fmt.Errorf("token too short to be valid JWT")
}
// Check raw token blacklist
if v.tokenBlacklist != nil {
if blacklisted, exists := v.tokenBlacklist.Get(token); exists && blacklisted != nil {
return fmt.Errorf("token is blacklisted (raw string) in cache")
}
}
// Parse JWT for further validation
parsedJWT, parseErr := v.parseJWT(token)
if parseErr != nil {
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
}
tokenType := v.determineTokenType(parsedJWT)
// Check token cache FIRST - if token is already verified and cached, return immediately
// This prevents false positives when multiple goroutines validate the same token concurrently
if claims, exists := v.tokenCache.GetCachedToken(token); exists && len(claims) > 0 {
return nil
}
// Check JTI blacklist for replay detection
if err := v.checkJTIBlacklist(parsedJWT, token); err != nil {
return err
}
// Rate limiting check
if !v.checkRateLimit() {
return fmt.Errorf("rate limit exceeded")
}
// Verify signature and claims
if err := v.VerifyJWTSignatureAndClaims(parsedJWT, token); err != nil {
if !strings.Contains(err.Error(), "token has expired") {
v.logErrorf("%s token verification failed: %v", tokenType, err)
}
return err
}
// Cache verified token
v.cacheVerifiedToken(token, parsedJWT.Claims)
// Add JTI to blacklist for replay prevention
v.addJTIToBlacklist(parsedJWT)
return nil
}
// VerifyJWTSignatureAndClaims verifies JWT signature using provider's public keys and validates standard claims
func (v *Validator) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
v.logDebugf("Verifying JWT signature and claims")
// Get JWKS URL
v.metadataMu.RLock()
jwksURL := v.jwksURL
v.metadataMu.RUnlock()
// Get JWKS from cache
jwks, err := v.getJWKS(context.Background(), jwksURL)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
// Extract key ID and algorithm from token header
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
}
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
// Find matching key in JWKS
matchingKey := v.findMatchingKey(jwks, kid)
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
// Convert JWK to PEM and verify signature
if err := v.verifyTokenSignature(token, matchingKey, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
// Detect token type and validate claims
isIDToken := v.detectTokenType(jwt, token)
expectedAudience := v.audience
if isIDToken {
expectedAudience = v.clientID
}
// Verify standard claims
v.metadataMu.RLock()
issuerURL := v.issuerURL
v.metadataMu.RUnlock()
if err := v.verifyStandardClaims(jwt, issuerURL, expectedAudience); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
return nil
}
// detectTokenType efficiently detects whether a token is an ID token or access token
func (v *Validator) detectTokenType(jwt *JWT, token string) bool {
// Use first 32 chars of token as cache key
cacheKey := token
if len(token) > 32 {
cacheKey = token[:32]
}
// Check cache first
if v.tokenTypeCache != nil {
if cachedData, found := v.tokenTypeCache.Get(cacheKey); found {
if isIDToken, ok := cachedData["is_id_token"].(bool); ok {
return isIDToken
}
}
}
// Check for ID token indicators
isIDToken := false
// 1. Check 'nonce' claim (definitive for ID tokens)
if nonce, ok := jwt.Claims["nonce"]; ok {
if _, ok := nonce.(string); ok {
v.cacheTokenType(cacheKey, true)
return true
}
}
// 2. Check 'typ' header for "at+jwt" (definitive for access tokens)
if typ, ok := jwt.Header["typ"].(string); ok && typ == "at+jwt" {
v.cacheTokenType(cacheKey, false)
return false
}
// 3. Check 'token_use' claim
if tokenUse, ok := jwt.Claims["token_use"].(string); ok {
switch tokenUse {
case "id":
v.cacheTokenType(cacheKey, true)
return true
case "access":
v.cacheTokenType(cacheKey, false)
return false
}
}
// 4. Check 'scope' claim (indicator for access tokens)
if scope, ok := jwt.Claims["scope"]; ok {
if _, ok := scope.(string); ok {
v.cacheTokenType(cacheKey, false)
return false
}
}
// 5. Check audience matching
if aud, ok := jwt.Claims["aud"]; ok {
if audStr, ok := aud.(string); ok && audStr == v.clientID {
isIDToken = true
} else if audArr, ok := aud.([]interface{}); ok && len(audArr) == 1 {
for _, val := range audArr {
if str, ok := val.(string); ok && str == v.clientID {
isIDToken = true
break
}
}
}
}
v.cacheTokenType(cacheKey, isIDToken)
return isIDToken
}
// Helper methods (stubs for interface compatibility)
func (v *Validator) parseJWT(token string) (*JWT, error) {
// This would call the actual JWT parsing function
// For now, returning a stub
return nil, fmt.Errorf("parseJWT not implemented")
}
func (v *Validator) determineTokenType(jwt *JWT) string {
if v.detectTokenType(jwt, "") {
return TokenTypeID
}
return TokenTypeAccess
}
func (v *Validator) checkJTIBlacklist(jwt *JWT, token string) error {
if v.disableReplayDetection {
return nil
}
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
// Skip for test tokens
if !strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
if v.tokenBlacklist != nil {
if blacklisted, exists := v.tokenBlacklist.Get(jti); exists && blacklisted != nil {
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
}
}
}
}
return nil
}
func (v *Validator) checkRateLimit() bool {
// Interface method call would go here
return true
}
func (v *Validator) cacheVerifiedToken(token string, claims map[string]interface{}) {
v.tokenCache.CacheToken(token, claims)
}
func (v *Validator) addJTIToBlacklist(jwt *JWT) {
if v.disableReplayDetection {
return
}
jti, ok := jwt.Claims["jti"].(string)
if !ok || jti == "" {
return
}
if v.tokenBlacklist != nil {
v.tokenBlacklist.Set(jti, map[string]interface{}{
"blacklisted_at": time.Now().Unix(),
"reason": "jti_replay_prevention",
})
}
}
func (v *Validator) cacheTokenType(cacheKey string, isIDToken bool) {
if v.tokenTypeCache != nil {
v.tokenTypeCache.Set(cacheKey, map[string]interface{}{
"is_id_token": isIDToken,
"cached_at": time.Now().Unix(),
})
}
}
func (v *Validator) getJWKS(ctx context.Context, jwksURL string) (*JWKS, error) {
// Interface method call would go here
return nil, fmt.Errorf("getJWKS not implemented")
}
func (v *Validator) findMatchingKey(jwks *JWKS, kid string) *JWK {
if jwks == nil {
return nil
}
for _, key := range jwks.Keys {
if key.Kid == kid {
return &key
}
}
return nil
}
func (v *Validator) verifyTokenSignature(token string, key *JWK, alg string) error {
// Interface method call would go here
return fmt.Errorf("verifyTokenSignature not implemented")
}
func (v *Validator) verifyStandardClaims(jwt *JWT, issuer, audience string) error {
// Interface method call would go here
return fmt.Errorf("verifyStandardClaims not implemented")
}
func (v *Validator) logDebugf(format string, args ...interface{}) {
// Logger interface call would go here
}
func (v *Validator) logErrorf(format string, args ...interface{}) {
// Logger interface call would go here
}
-139
View File
@@ -1,139 +0,0 @@
// Package token provides token verification and management functionality
package token
import (
"fmt"
"strings"
"time"
traefikoidc "github.com/lukaszraczylo/traefikoidc"
)
// Verifier handles token verification operations
type Verifier struct {
tokenCache TokenCache
tokenBlacklist Cache
jwkCache JWKCache
limiter RateLimiter
logger Logger
}
// Cache interface for token operations
type Cache interface {
Get(key string) (interface{}, bool)
Set(key string, value interface{}, ttl time.Duration)
}
// TokenCache interface for verified token storage
type TokenCache interface {
Get(key string) (map[string]interface{}, bool)
Set(key string, claims map[string]interface{}, ttl time.Duration)
}
// JWKCache interface for key management
type JWKCache interface {
GetJWKS(providerURL string) (*traefikoidc.JWKSet, error)
}
// RateLimiter interface for request limiting
type RateLimiter interface {
Allow() bool
}
// Logger interface for logging
type Logger interface {
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// JWT represents a parsed JWT token
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
}
// NewVerifier creates a new token verifier
func NewVerifier(tokenCache TokenCache, tokenBlacklist Cache, jwkCache JWKCache, limiter RateLimiter, logger Logger) *Verifier {
return &Verifier{
tokenCache: tokenCache,
tokenBlacklist: tokenBlacklist,
jwkCache: jwkCache,
limiter: limiter,
logger: logger,
}
}
// VerifyToken verifies the validity of an ID token or access token
func (v *Verifier) VerifyToken(token string, clientID string, jwksURL string, issuerURL string) error {
if token == "" {
return fmt.Errorf("invalid JWT format: token is empty")
}
if strings.Count(token, ".") != 2 {
return fmt.Errorf("invalid JWT format: expected JWT with 3 parts, got %d parts", strings.Count(token, ".")+1)
}
if len(token) < 10 {
return fmt.Errorf("token too short to be valid JWT")
}
// Check blacklist
if v.tokenBlacklist != nil {
if blacklisted, exists := v.tokenBlacklist.Get(token); exists && blacklisted != nil {
return fmt.Errorf("token is blacklisted")
}
}
// Check cache first
if claims, exists := v.tokenCache.Get(token); exists && len(claims) > 0 {
return nil
}
// Rate limiting
if !v.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
// Parse and verify JWT
jwt, err := v.parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
if err := v.verifyJWTSignatureAndClaims(jwt, token, clientID, jwksURL, issuerURL); err != nil {
return err
}
// Cache successful verification
v.cacheVerifiedToken(token, jwt.Claims)
return nil
}
// parseJWT parses a JWT token into its components
func (v *Verifier) parseJWT(token string) (*JWT, error) {
// This would contain the actual JWT parsing logic
// For now, return a placeholder
return &JWT{
Header: make(map[string]interface{}),
Claims: make(map[string]interface{}),
}, nil
}
// verifyJWTSignatureAndClaims verifies JWT signature and claims
func (v *Verifier) verifyJWTSignatureAndClaims(jwt *JWT, token string, clientID string, jwksURL string, issuerURL string) error {
// This would contain the actual signature verification logic
// For now, return nil (placeholder)
return nil
}
// cacheVerifiedToken stores a successfully verified token
func (v *Verifier) cacheVerifiedToken(token string, claims map[string]interface{}) {
if expClaim, ok := claims["exp"].(float64); ok {
expirationTime := time.Unix(int64(expClaim), 0)
duration := time.Until(expirationTime)
if duration > 0 {
v.tokenCache.Set(token, claims, duration)
}
}
}
-457
View File
@@ -1,457 +0,0 @@
package token
import (
"strings"
"testing"
"time"
traefikoidc "github.com/lukaszraczylo/traefikoidc"
)
// Mock implementations for testing
type MockTokenCache struct {
data map[string]map[string]interface{}
}
func (m *MockTokenCache) Get(key string) (map[string]interface{}, bool) {
if m.data == nil {
return nil, false
}
value, exists := m.data[key]
return value, exists
}
func (m *MockTokenCache) Set(key string, claims map[string]interface{}, ttl time.Duration) {
if m.data == nil {
m.data = make(map[string]map[string]interface{})
}
m.data[key] = claims
}
type MockCache struct {
data map[string]interface{}
}
func (m *MockCache) Get(key string) (interface{}, bool) {
if m.data == nil {
return nil, false
}
value, exists := m.data[key]
return value, exists
}
func (m *MockCache) Set(key string, value interface{}, ttl time.Duration) {
if m.data == nil {
m.data = make(map[string]interface{})
}
m.data[key] = value
}
type MockJWKCache struct{}
func (m *MockJWKCache) GetJWKS(providerURL string) (*traefikoidc.JWKSet, error) {
return &traefikoidc.JWKSet{
Keys: []traefikoidc.JWK{
{
Kid: "test-key",
Kty: "RSA",
Use: "sig",
Alg: "RS256",
},
},
}, nil
}
type MockRateLimiter struct {
allow bool
}
func (m *MockRateLimiter) Allow() bool {
return m.allow
}
type MockLogger struct {
debugMessages []string
errorMessages []string
}
func (m *MockLogger) Debugf(format string, args ...interface{}) {
m.debugMessages = append(m.debugMessages, format)
}
func (m *MockLogger) Errorf(format string, args ...interface{}) {
m.errorMessages = append(m.errorMessages, format)
}
func TestNewVerifier(t *testing.T) {
tokenCache := &MockTokenCache{}
tokenBlacklist := &MockCache{}
jwkCache := &MockJWKCache{}
limiter := &MockRateLimiter{allow: true}
logger := &MockLogger{}
verifier := NewVerifier(tokenCache, tokenBlacklist, jwkCache, limiter, logger)
if verifier == nil {
t.Fatal("NewVerifier returned nil")
}
if verifier.tokenCache != tokenCache {
t.Error("TokenCache not set correctly")
}
if verifier.tokenBlacklist != tokenBlacklist {
t.Error("TokenBlacklist not set correctly")
}
// Note: Interface comparison would require reflecting on the actual implementation
// For now, we just check that the field was set to something non-nil
if verifier.jwkCache == nil {
t.Error("JWKCache not set correctly")
}
if verifier.limiter != limiter {
t.Error("RateLimiter not set correctly")
}
if verifier.logger != logger {
t.Error("Logger not set correctly")
}
}
func TestVerifierBasicFunctionality(t *testing.T) {
tokenCache := &MockTokenCache{}
tokenBlacklist := &MockCache{}
jwkCache := &MockJWKCache{}
limiter := &MockRateLimiter{allow: true}
logger := &MockLogger{}
verifier := NewVerifier(tokenCache, tokenBlacklist, jwkCache, limiter, logger)
// Test that the verifier was created successfully
if verifier == nil {
t.Fatal("Expected non-nil verifier")
}
}
func TestJWKSStructure(t *testing.T) {
jwks := &traefikoidc.JWKSet{
Keys: []traefikoidc.JWK{
{
Kid: "test-key-1",
Kty: "RSA",
Use: "sig",
Alg: "RS256",
},
{
Kid: "test-key-2",
Kty: "RSA",
Use: "sig",
Alg: "RS256",
},
},
}
if len(jwks.Keys) != 2 {
t.Errorf("Expected 2 keys, got %d", len(jwks.Keys))
}
if jwks.Keys[0].Kid != "test-key-1" {
t.Errorf("Expected Kid 'test-key-1', got '%s'", jwks.Keys[0].Kid)
}
if jwks.Keys[1].Kid != "test-key-2" {
t.Errorf("Expected Kid 'test-key-2', got '%s'", jwks.Keys[1].Kid)
}
}
func TestJWKStructure(t *testing.T) {
jwk := traefikoidc.JWK{
Kid: "test-key",
Kty: "RSA",
Use: "sig",
Alg: "RS256",
N: "test-modulus",
E: "test-exponent",
}
if jwk.Kid != "test-key" {
t.Errorf("Expected Kid 'test-key', got '%s'", jwk.Kid)
}
if jwk.Kty != "RSA" {
t.Errorf("Expected Kty 'RSA', got '%s'", jwk.Kty)
}
if jwk.Use != "sig" {
t.Errorf("Expected Use 'sig', got '%s'", jwk.Use)
}
if jwk.Alg != "RS256" {
t.Errorf("Expected Alg 'RS256', got '%s'", jwk.Alg)
}
}
func TestVerifyToken(t *testing.T) {
tests := []struct {
name string
token string
clientID string
jwksURL string
issuerURL string
rateLimitAllow bool
cacheData map[string]map[string]interface{}
blacklistData map[string]interface{}
expectedError string
}{
{
name: "Empty token",
token: "",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: true,
expectedError: "invalid JWT format: token is empty",
},
{
name: "Invalid JWT format - too few parts",
token: "header.payload",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: true,
expectedError: "invalid JWT format: expected JWT with 3 parts, got 2 parts",
},
{
name: "Invalid JWT format - too many parts",
token: "header.payload.signature.extra",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: true,
expectedError: "invalid JWT format: expected JWT with 3 parts, got 4 parts",
},
{
name: "Token too short",
token: "a.b.c",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: true,
expectedError: "token too short to be valid JWT",
},
{
name: "Blacklisted token",
token: "valid.format.token",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: true,
blacklistData: map[string]interface{}{"valid.format.token": true},
expectedError: "token is blacklisted",
},
{
name: "Cached token - success",
token: "valid.format.token",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: true,
cacheData: map[string]map[string]interface{}{"valid.format.token": {"sub": "user123"}},
expectedError: "",
},
{
name: "Rate limit exceeded",
token: "valid.format.token",
clientID: "test-client",
jwksURL: "https://example.com/jwks",
issuerURL: "https://example.com",
rateLimitAllow: false,
expectedError: "rate limit exceeded",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenCache := &MockTokenCache{data: tt.cacheData}
tokenBlacklist := &MockCache{data: tt.blacklistData}
jwkCache := &MockJWKCache{}
limiter := &MockRateLimiter{allow: tt.rateLimitAllow}
logger := &MockLogger{}
verifier := NewVerifier(tokenCache, tokenBlacklist, jwkCache, limiter, logger)
err := verifier.VerifyToken(tt.token, tt.clientID, tt.jwksURL, tt.issuerURL)
if tt.expectedError == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.expectedError)
} else if !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("Expected error containing '%s', got: %v", tt.expectedError, err)
}
}
})
}
}
func TestParseJWT(t *testing.T) {
tokenCache := &MockTokenCache{}
tokenBlacklist := &MockCache{}
jwkCache := &MockJWKCache{}
limiter := &MockRateLimiter{allow: true}
logger := &MockLogger{}
verifier := NewVerifier(tokenCache, tokenBlacklist, jwkCache, limiter, logger)
// Test parseJWT with a valid format token
jwt, err := verifier.parseJWT("header.payload.signature")
if err != nil {
t.Errorf("Expected no error parsing JWT, got: %v", err)
}
if jwt == nil {
t.Error("Expected non-nil JWT object")
return
}
if jwt.Header == nil {
t.Error("Expected non-nil Header map")
}
if jwt.Claims == nil {
t.Error("Expected non-nil Claims map")
}
}
func TestVerifyJWTSignatureAndClaims(t *testing.T) {
tokenCache := &MockTokenCache{}
tokenBlacklist := &MockCache{}
jwkCache := &MockJWKCache{}
limiter := &MockRateLimiter{allow: true}
logger := &MockLogger{}
verifier := NewVerifier(tokenCache, tokenBlacklist, jwkCache, limiter, logger)
jwt := &JWT{
Header: map[string]interface{}{"alg": "RS256"},
Claims: map[string]interface{}{"sub": "user123", "exp": float64(time.Now().Add(time.Hour).Unix())},
}
// Test signature verification (currently returns nil - placeholder)
err := verifier.verifyJWTSignatureAndClaims(jwt, "test.token.here", "client-id", "https://example.com/jwks", "https://example.com")
if err != nil {
t.Errorf("Expected no error from placeholder verification, got: %v", err)
}
}
func TestCacheVerifiedToken(t *testing.T) {
tokenCache := &MockTokenCache{}
tokenBlacklist := &MockCache{}
jwkCache := &MockJWKCache{}
limiter := &MockRateLimiter{allow: true}
logger := &MockLogger{}
verifier := NewVerifier(tokenCache, tokenBlacklist, jwkCache, limiter, logger)
tests := []struct {
name string
token string
claims map[string]interface{}
expected bool
}{
{
name: "Valid expiration time",
token: "test-token-1",
claims: map[string]interface{}{"exp": float64(time.Now().Add(time.Hour).Unix())},
expected: true,
},
{
name: "Expired token",
token: "test-token-2",
claims: map[string]interface{}{"exp": float64(time.Now().Add(-time.Hour).Unix())},
expected: false,
},
{
name: "No expiration claim",
token: "test-token-3",
claims: map[string]interface{}{"sub": "user123"},
expected: false,
},
{
name: "Invalid expiration type",
token: "test-token-4",
claims: map[string]interface{}{"exp": "invalid"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear cache before test
tokenCache.data = make(map[string]map[string]interface{})
verifier.cacheVerifiedToken(tt.token, tt.claims)
_, exists := tokenCache.Get(tt.token)
if exists != tt.expected {
t.Errorf("Expected cache existence: %v, got: %v", tt.expected, exists)
}
})
}
}
func TestMockInterfaces(t *testing.T) {
// Test MockTokenCache
tokenCache := &MockTokenCache{}
claims := map[string]interface{}{"sub": "user123", "exp": 1234567890}
tokenCache.Set("test-token", claims, time.Hour)
retrieved, exists := tokenCache.Get("test-token")
if !exists {
t.Error("Expected token to exist in cache")
}
if retrieved["sub"] != "user123" {
t.Errorf("Expected sub 'user123', got '%v'", retrieved["sub"])
}
// Test MockCache
cache := &MockCache{}
cache.Set("test-key", "test-value", time.Hour)
value, exists := cache.Get("test-key")
if !exists {
t.Error("Expected key to exist in cache")
}
if value != "test-value" {
t.Errorf("Expected 'test-value', got '%v'", value)
}
// Test MockRateLimiter
limiter := &MockRateLimiter{allow: true}
if !limiter.Allow() {
t.Error("Expected rate limiter to allow request")
}
limiter.allow = false
if limiter.Allow() {
t.Error("Expected rate limiter to deny request")
}
// Test MockLogger
logger := &MockLogger{}
logger.Debugf("test debug message")
logger.Errorf("test error message")
if len(logger.debugMessages) != 1 {
t.Errorf("Expected 1 debug message, got %d", len(logger.debugMessages))
}
if len(logger.errorMessages) != 1 {
t.Errorf("Expected 1 error message, got %d", len(logger.errorMessages))
}
}
+91
View File
@@ -0,0 +1,91 @@
package utils
import (
"github.com/lukaszraczylo/traefikoidc/internal/cleanup"
"github.com/lukaszraczylo/traefikoidc/internal/recovery"
)
// LoggerInterface defines the common logger interface used across the package
type LoggerInterface interface {
Infof(format string, args ...interface{})
Debugf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// ============================================================================
// RECOVERY LOGGER WRAPPER
// ============================================================================
// recoveryLoggerWrapper wraps a logger to match recovery.Logger interface
type recoveryLoggerWrapper struct {
logger LoggerInterface
}
// WrapLoggerForRecovery wraps a logger for use with recovery modules
func WrapLoggerForRecovery(logger LoggerInterface) recovery.Logger {
return &recoveryLoggerWrapper{logger: logger}
}
// Logf logs an informational message
func (lw *recoveryLoggerWrapper) Logf(format string, args ...interface{}) {
if lw.logger != nil {
lw.logger.Infof(format, args...)
}
}
// ErrorLogf logs an error message
func (lw *recoveryLoggerWrapper) ErrorLogf(format string, args ...interface{}) {
if lw.logger != nil {
lw.logger.Errorf(format, args...)
}
}
// DebugLogf logs a debug message
func (lw *recoveryLoggerWrapper) DebugLogf(format string, args ...interface{}) {
if lw.logger != nil {
lw.logger.Debugf(format, args...)
}
}
// ============================================================================
// CLEANUP LOGGER WRAPPER
// ============================================================================
// cleanupLoggerWrapper wraps a logger to match cleanup.Logger interface
type cleanupLoggerWrapper struct {
logger LoggerInterface
}
// WrapLoggerForCleanup wraps a logger for use with cleanup modules
func WrapLoggerForCleanup(logger LoggerInterface) cleanup.Logger {
return &cleanupLoggerWrapper{logger: logger}
}
// Logf logs an informational message
func (lw *cleanupLoggerWrapper) Logf(format string, args ...interface{}) {
if lw.logger != nil {
lw.logger.Infof(format, args...)
}
}
// ErrorLogf logs an error message
func (lw *cleanupLoggerWrapper) ErrorLogf(format string, args ...interface{}) {
if lw.logger != nil {
lw.logger.Errorf(format, args...)
}
}
// DebugLogf logs a debug message
func (lw *cleanupLoggerWrapper) DebugLogf(format string, args ...interface{}) {
if lw.logger != nil {
lw.logger.Debugf(format, args...)
}
}
// ============================================================================
// SESSION LOGGER WRAPPER
// ============================================================================
// Note: Session logger wrapper is not included here because session.Logger
// has a different interface (Debug/Info/Warn/Error instead of Logf/ErrorLogf/DebugLogf).
// Each package should implement its own session logger adapter as needed.
+29 -4
View File
@@ -10,6 +10,12 @@ import (
"time"
)
const (
// metadataCacheVersion is incremented when cache format changes
// This ensures old cached data is automatically ignored
metadataCacheVersion = "v2"
)
// MetadataCache wraps UniversalCache for metadata operations
type MetadataCache struct {
cache *UniversalCache
@@ -17,6 +23,11 @@ type MetadataCache struct {
wg *sync.WaitGroup
}
// versionedKey adds version prefix to cache keys
func (mc *MetadataCache) versionedKey(key string) string {
return metadataCacheVersion + ":" + key
}
// MetadataCacheEntry for compatibility
type MetadataCacheEntry struct {
}
@@ -55,12 +66,14 @@ func (mc *MetadataCache) Set(providerURL string, metadata *ProviderMetadata, ttl
return fmt.Errorf("failed to marshal metadata: %w", err)
}
return mc.cache.Set(providerURL, data, ttl)
// Use versioned key to prevent stale data issues
return mc.cache.Set(mc.versionedKey(providerURL), data, ttl)
}
// Get retrieves provider metadata from cache
func (mc *MetadataCache) Get(providerURL string) (*ProviderMetadata, bool) {
value, exists := mc.cache.Get(providerURL)
// Use versioned key to prevent stale data issues
value, exists := mc.cache.Get(mc.versionedKey(providerURL))
if !exists {
mc.logger.Debugf("MetadataCache: MISS for %s", providerURL)
return nil, false
@@ -78,9 +91,21 @@ func (mc *MetadataCache) Get(providerURL string) (*ProviderMetadata, bool) {
return nil, false
}
// Debug: log first 100 chars of cached data to diagnose unmarshal issues
dataPreview := string(data)
if len(dataPreview) > 100 {
dataPreview = dataPreview[:100]
}
mc.logger.Debugf("MetadataCache: Attempting to unmarshal for %s, data preview: %s", providerURL, dataPreview)
var metadata ProviderMetadata
if err := json.Unmarshal(data, &metadata); err != nil {
mc.logger.Errorf("MetadataCache: Failed to unmarshal metadata for %s: %v", providerURL, err)
// Graceful degradation: corrupt data is treated as cache miss
mc.logger.Errorf("MetadataCache: Corrupt data detected for %s: %v (preview: %s) - deleting and treating as miss", providerURL, err, dataPreview)
// Delete corrupt entry to prevent repeated errors (use versioned key)
mc.cache.Delete(mc.versionedKey(providerURL))
return nil, false
}
@@ -183,7 +208,7 @@ func (mc *MetadataCache) CleanupExpired() {
// Delete removes an entry from the cache
func (mc *MetadataCache) Delete(key string) {
mc.cache.Delete(key)
mc.cache.Delete(mc.versionedKey(key))
}
// Mutex returns the cache mutex for testing