Files
traefikoidc/internal/providers/registry.go
T
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00

172 lines
4.7 KiB
Go

package providers
import (
"net/url"
"strings"
"sync"
)
// ProviderRegistry holds the registered OIDC provider implementations and
// memoizes the result of issuer-URL detection.
//
// All fields are guarded by mu; the exported methods acquire it themselves, so
// a registry may be shared between goroutines. Construct one with
// NewProviderRegistry so the maps are allocated.
type ProviderRegistry struct {
	// cache maps an issuer URL to the provider detected for it. The stored
	// value may be nil when detection found no provider for that URL.
	cache map[string]OIDCProvider
	// typeMap gives direct lookup of a registered provider by its type.
	typeMap map[ProviderType]OIDCProvider
	// providers lists the registered providers in the order they are
	// scanned during detection.
	providers []OIDCProvider
	// mu serializes writers and allows concurrent readers.
	mu sync.RWMutex

	// maxCacheSize is the number of cache entries at which an existing entry
	// is evicted before a new one is stored, keeping memory use bounded.
	maxCacheSize int
	// cacheCount tracks how many entries cache currently holds.
	cacheCount int
}
// NewProviderRegistry returns an empty, ready-to-use ProviderRegistry.
//
// The provider list and both maps are allocated up front, and the detection
// cache is capped at 1000 entries so it cannot grow without bound.
func NewProviderRegistry() *ProviderRegistry {
	// Upper bound on cached detection results.
	const defaultMaxCacheSize = 1000

	registry := new(ProviderRegistry)
	registry.providers = []OIDCProvider{}
	registry.cache = map[string]OIDCProvider{}
	registry.typeMap = map[ProviderType]OIDCProvider{}
	registry.maxCacheSize = defaultMaxCacheSize
	registry.cacheCount = 0
	return registry
}
// RegisterProvider adds a provider to the registry, or replaces the provider
// previously registered for the same type.
//
// The provider list and the type-to-provider map are kept in agreement: when a
// provider of the same type is already present, its slot in the list is
// overwritten instead of a duplicate being appended, so detection and
// GetProviderByType always resolve to the same instance.
//
// Cached detection results are discarded because they were computed against
// the previous provider set and may no longer be correct.
//
// A nil provider is ignored.
func (r *ProviderRegistry) RegisterProvider(provider OIDCProvider) {
	if provider == nil {
		return
	}
	providerType := provider.GetType()

	r.mu.Lock()
	defer r.mu.Unlock()

	// Replace in place so the original scan position of this type is kept.
	replaced := false
	for i, existing := range r.providers {
		if existing != nil && existing.GetType() == providerType {
			r.providers[i] = provider
			replaced = true
			break
		}
	}
	if !replaced {
		r.providers = append(r.providers, provider)
	}
	r.typeMap[providerType] = provider

	// Detection outcomes depend on which providers are registered, so any
	// memoized result is stale from this point on.
	if len(r.cache) > 0 || r.cacheCount != 0 {
		r.cache = make(map[string]OIDCProvider)
		r.cacheCount = 0
	}
}
// GetProviderByType looks up the provider registered for providerType.
// It returns nil when no provider of that type has been registered.
func (r *ProviderRegistry) GetProviderByType(providerType ProviderType) OIDCProvider {
	r.mu.RLock()
	registered, ok := r.typeMap[providerType]
	r.mu.RUnlock()

	if !ok {
		return nil
	}
	return registered
}
// GetRegisteredProviders returns the type of every registered provider.
// The result is a fresh slice; its order follows map iteration and is
// therefore unspecified.
func (r *ProviderRegistry) GetRegisteredProviders() []ProviderType {
	r.mu.RLock()
	defer r.mu.RUnlock()

	registered := make([]ProviderType, len(r.typeMap))
	next := 0
	for key := range r.typeMap {
		registered[next] = key
		next++
	}
	return registered
}
// ClearCache drops every memoized detection result by swapping in a fresh,
// empty cache map and resetting the entry counter. Registered providers are
// not affected.
func (r *ProviderRegistry) ClearCache() {
	r.mu.Lock()
	r.cache, r.cacheCount = make(map[string]OIDCProvider), 0
	r.mu.Unlock()
}
// evictOldestCacheEntry frees one cache slot by deleting a single entry and
// decrementing the entry counter. It does nothing when the cache is empty.
//
// Despite the name, the victim is whichever key map iteration yields first,
// which is effectively arbitrary rather than the oldest entry; an LRU policy
// would be needed for true age-based eviction.
//
// The method takes no lock itself and mutates the cache.
func (r *ProviderRegistry) evictOldestCacheEntry() {
	for victim := range r.cache {
		delete(r.cache, victim)
		r.cacheCount--
		return
	}
}
// DetectProvider returns the OIDC provider that matches issuerURL, memoizing
// the answer (including a nil answer) per issuer URL.
//
// Lookups use double-checked locking: the cache is consulted under the read
// lock first, and only on a miss is the write lock taken, the cache checked
// again, and detection actually run. When the cache has reached its size limit
// one existing entry is evicted before the new result is stored.
func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
	// Fast path: shared lock only.
	r.mu.RLock()
	cached, hit := r.cache[issuerURL]
	r.mu.RUnlock()
	if hit {
		return cached
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Another goroutine may have filled the entry between the two locks.
	if cached, hit = r.cache[issuerURL]; hit {
		return cached
	}

	result := r.detectProviderUnsafe(issuerURL)

	// Make room before inserting so the cache never exceeds its bound.
	if r.cacheCount >= r.maxCacheSize {
		r.evictOldestCacheEntry()
	}
	r.cache[issuerURL] = result
	r.cacheCount++

	return result
}
// detectProviderUnsafe runs the provider detection heuristics for issuerURL.
//
// It takes no lock itself; the caller is expected to hold the registry lock.
//
// The URL must parse and carry both a scheme and a host, otherwise nil is
// returned. Registered providers are then tried in order and the first one
// whose rule matches is returned:
//
//   - Google:   accounts.google.com or a subdomain of it
//   - Azure:    login.microsoftonline.com / sts.windows.net or a subdomain
//   - GitHub:   github.com or a subdomain
//   - GitLab:   gitlab.com or a subdomain
//   - Auth0:    any subdomain of auth0.com
//   - Okta:     any subdomain of okta.com, oktapreview.com or okta-emea.com
//   - Cognito:  host contains "cognito-idp" and is a subdomain of
//     amazonaws.com or amazonaws.com.cn
//   - Keycloak: host contains "keycloak" or the path contains "/auth/realms/"
//     (self-hosted, so this stays a loose heuristic)
//
// Well-known domains are matched on DNS label boundaries against the hostname
// (port and trailing dot removed, lower-cased) rather than by substring, so
// look-alike hosts such as "notgithub.com" or "github.com.attacker.example"
// are not mistaken for the real service.
//
// If no rule matches, the registered generic provider is returned, or nil when
// none is registered.
func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
	parsed, err := url.Parse(issuerURL)
	if err != nil {
		return nil
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return nil
	}

	// Hostname drops any port; a trailing dot marks a fully-qualified name
	// and is irrelevant for matching.
	host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")

	// isSubdomainOf reports whether host is a strict subdomain of any domain.
	isSubdomainOf := func(domains ...string) bool {
		for _, domain := range domains {
			if strings.HasSuffix(host, "."+domain) {
				return true
			}
		}
		return false
	}
	// isDomainOrSubdomainOf additionally accepts the bare domain itself.
	isDomainOrSubdomainOf := func(domains ...string) bool {
		for _, domain := range domains {
			if host == domain {
				return true
			}
		}
		return isSubdomainOf(domains...)
	}

	for _, p := range r.providers {
		matched := false
		switch p.GetType() {
		case ProviderTypeGoogle:
			matched = isDomainOrSubdomainOf("accounts.google.com")
		case ProviderTypeAzure:
			matched = isDomainOrSubdomainOf("login.microsoftonline.com", "sts.windows.net")
		case ProviderTypeGitHub:
			matched = isDomainOrSubdomainOf("github.com")
		case ProviderTypeAuth0:
			matched = isSubdomainOf("auth0.com")
		case ProviderTypeOkta:
			matched = isSubdomainOf("okta.com", "oktapreview.com", "okta-emea.com")
		case ProviderTypeKeycloak:
			matched = strings.Contains(host, "keycloak") ||
				strings.Contains(parsed.Path, "/auth/realms/")
		case ProviderTypeAWSCognito:
			matched = strings.Contains(host, "cognito-idp") &&
				isSubdomainOf("amazonaws.com", "amazonaws.com.cn")
		case ProviderTypeGitLab:
			matched = isDomainOrSubdomainOf("gitlab.com")
		}
		if matched {
			return p
		}
	}

	// Nothing specific matched: fall back to the generic provider, if any.
	for _, p := range r.providers {
		if p.GetType() == ProviderTypeGeneric {
			return p
		}
	}
	return nil
}