Files
traefikoidc/config/loader.go
T
lukaszraczylo e64fc7f730 Add redis support for distributed caching (#83)
* Add redis support for distributed caching

* Move towards the self-provided Redis connection pool and RESP protocol implementation.
Official redis client library won't work with yaegi.

* fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* ... and another all nighter.

* fixup! ... and another all nighter.

* fixup! fixup! ... and another all nighter.

* fixup! fixup! fixup! ... and another all nighter.

* Resolve issue #85 by adding ability to set custom claims in JWT tokens

* Remove redundant validation in auth middleware ( issue #89 )

* Add ability to set cookie prefix for session cookies ( #87 )

* fixup! Add ability to set cookie prefix for session cookies ( #87 )

* Add ability to set cookie max age - issue #91

* Potential fix for code scanning alert no. 10: Size computation for allocation may overflow

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fixup! Merge main into 0.8.0-redis: resolve conflicts

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-11-30 02:18:46 +00:00

397 lines
11 KiB
Go

// Package config provides configuration loading and merging logic
package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"github.com/lukaszraczylo/traefikoidc/internal/features"
"gopkg.in/yaml.v3"
)
// ConfigLoader handles loading configuration from various sources
type ConfigLoader struct {
migrator *ConfigMigrator
envPrefix string
configPaths []string
}
// NewConfigLoader creates a new configuration loader
func NewConfigLoader() *ConfigLoader {
return &ConfigLoader{
migrator: NewConfigMigrator(),
envPrefix: "TRAEFIKOIDC_",
configPaths: getDefaultConfigPaths(),
}
}
// getDefaultConfigPaths returns default configuration file paths to check
func getDefaultConfigPaths() []string {
return []string{
"traefik-oidc.yaml",
"traefik-oidc.yml",
"traefik-oidc.json",
"config.yaml",
"config.yml",
"config.json",
"/etc/traefik-oidc/config.yaml",
"/etc/traefik-oidc/config.json",
}
}
// Load loads configuration from all available sources
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
// Start with defaults
config := NewUnifiedConfig()
// Try to load from file
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
config = l.mergeConfigs(config, fileConfig)
}
// Load from environment variables
l.LoadFromEnv(config)
// Validate the final configuration
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("configuration validation failed: %w", err)
}
return config, nil
}
// LoadFromFile loads configuration from a file
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
// Use provided paths or default paths
searchPaths := paths
if len(searchPaths) == 0 {
searchPaths = l.configPaths
}
// Check for config file in environment variable
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
searchPaths = append([]string{envPath}, searchPaths...)
}
// Try each path
for _, path := range searchPaths {
if _, err := os.Stat(path); err == nil {
return l.loadFile(path)
}
}
// No config file found, not an error (use defaults)
return nil, nil
}
// loadFile loads a specific configuration file
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
}
// Ensure the path is within expected directories (current dir or subdirs)
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
}
// Read the file with validated path
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
}
// Check if unified config is enabled
if features.IsUnifiedConfigEnabled() {
// Use migrator to handle any version
config, warnings, err := l.migrator.Migrate(data)
if err != nil {
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
}
// Log warnings
for _, warning := range warnings {
// In production, use proper logging
fmt.Printf("Config Warning (%s): %s\n", path, warning)
}
return config, nil
}
// Legacy path: load old config and convert
ext := strings.ToLower(filepath.Ext(path))
var oldConfig Config
switch ext {
case ".json":
if err := json.Unmarshal(data, &oldConfig); err != nil {
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
}
case ".yaml", ".yml":
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
}
default:
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
}
return FromOldConfig(&oldConfig), nil
}
// LoadFromEnv loads configuration from environment variables
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
// Provider configuration
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
// Session configuration
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
// Security configuration
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
// Cache configuration
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
// MaxEntrySize is int64, skip for now
// Rate limiting
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
// Logging
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
// Redis configuration (already handled by its own LoadFromEnv)
config.Redis.LoadFromEnv()
// Feature flags
features.GetManager().LoadFromEnv()
}
// Helper methods for environment variable loading
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = value
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = value
return
}
}
}
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = strings.ToLower(value) == "true" || value == "1"
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = strings.ToLower(value) == "true" || value == "1"
return
}
}
}
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
var i int
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
*target = i
return
}
}
// Try without prefix
if value := os.Getenv(key); value != "" {
var i int
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
*target = i
return
}
}
}
}
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
for _, key := range keys {
if value := os.Getenv(l.envPrefix + key); value != "" {
*target = splitAndTrim(value)
return
}
// Try without prefix
if value := os.Getenv(key); value != "" {
*target = splitAndTrim(value)
return
}
}
}
func splitAndTrim(s string) []string {
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
if trimmed := strings.TrimSpace(part); trimmed != "" {
result = append(result, trimmed)
}
}
return result
}
// mergeConfigs merges two configurations, with source overriding target
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
if source == nil {
return target
}
if target == nil {
return source
}
// Use reflection for deep merge
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
return target
}
// mergeStructs recursively merges two structs
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
for i := 0; i < source.NumField(); i++ {
sourceField := source.Field(i)
targetField := target.Field(i)
// Skip if source field is zero value
if isZeroValue(sourceField) {
continue
}
switch sourceField.Kind() {
case reflect.Struct:
// Recursively merge structs
l.mergeStructs(targetField, sourceField)
case reflect.Slice:
// Replace slice if source has values
if sourceField.Len() > 0 {
targetField.Set(sourceField)
}
case reflect.Map:
// Merge maps
if !sourceField.IsNil() {
if targetField.IsNil() {
targetField.Set(reflect.MakeMap(sourceField.Type()))
}
for _, key := range sourceField.MapKeys() {
targetField.SetMapIndex(key, sourceField.MapIndex(key))
}
}
default:
// Replace value
targetField.Set(sourceField)
}
}
}
// isZeroValue checks if a reflect.Value is a zero value
func isZeroValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return v.IsNil()
case reflect.Slice, reflect.Map:
return v.IsNil() || v.Len() == 0
case reflect.Struct:
// Check if all fields are zero
for i := 0; i < v.NumField(); i++ {
if !isZeroValue(v.Field(i)) {
return false
}
}
return true
default:
zero := reflect.Zero(v.Type())
return reflect.DeepEqual(v.Interface(), zero.Interface())
}
}
// SaveToFile saves the configuration to a file
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
// Clean and validate path to prevent traversal attacks
cleanPath := filepath.Clean(path)
// Check for path traversal attempts
if strings.Contains(cleanPath, "..") {
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
}
// Ensure the path is within expected directories
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
}
ext := strings.ToLower(filepath.Ext(absPath))
var data []byte
switch ext {
case ".json":
data, err = json.MarshalIndent(config, "", " ")
case ".yaml", ".yml":
data, err = yaml.Marshal(config)
default:
return fmt.Errorf("unsupported file extension: %s", ext)
}
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create directory if it doesn't exist with secure permissions
dir := filepath.Dir(absPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Write file with secure permissions (owner read/write only)
if err := os.WriteFile(absPath, data, 0600); err != nil {
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
}
return nil
}