Files
traefikoidc/utilities.go
T
lukaszraczylo 17e3f8ef62 fix: snapshot patterns for refresh-tracker and metadata URLs
Two related lock-free snapshot refactors addressing the remaining
post-v1.0.16 code-review findings.

1. refreshAttemptTracker: per-field atomic.Load/Store -> atomic.Value
   snapshot of *attemptState (refresh_coordinator.go).

   Previously each tracker held five independently-atomic fields. The
   cooldown-exit reset wrote cooldownEndNano = 0 first, then separately
   stored attempts = 1 and windowStartNano = now. A concurrent
   isInCooldown call could observe cooldownEndNano = 0 (reset just
   completed) with attempts still at MaxRefreshAttempts, immediately
   triggering a fresh cooldown — a benign double-trigger race that
   nonetheless meant the state machine had observable intermediate
   states.

   New design: state is a *attemptState (immutable) published via
   atomic.Value. All transitions (record/success/failure/window-reset/
   cooldown-enter/cooldown-exit) go through mutateState, which runs a
   CAS loop: load current snapshot -> construct fresh snapshot ->
   CompareAndSwap. Either the entire new state publishes or none of
   it does — no intermediate visibility, no cross-field race.

   Under Yaegi this collapses 3-5 per-field atomic dispatches into one
   atomic.Value.Load on the read path. Write paths pay an extra
   allocation for the new snapshot but avoid the cross-field hazard.

2. MetadataSnapshot: hot-path readers use atomic.Value instead of
   metadataMu.RLock (middleware.go, types.go, main.go, utilities.go).

   middleware.ServeHTTP previously took metadataMu.RLock on every
   non-bypass request to read the single field issuerURL. Under Yaegi
   each RLock acquisition costs 1-5ms of interpreter dispatch.
   updateMetadataEndpoints now also publishes an immutable
   *MetadataSnapshot via atomic.Value; the hot-path reader loads it
   in one op via t.metadataSnap(). Falls back to the legacy
   metadataMu.RLock pattern when the snapshot is unpublished (some
   test setups initialize the struct fields directly without going
   through updateMetadataEndpoints).

   Less-frequent callers (helpers, logout, token_introspection) still
   take metadataMu.RLock and are unchanged. The snapshot is a strict
   subset of the metadataMu-protected fields, so those readers see
   identical data.

Note on atomic.Pointer[T]: this would have been the cleaner type, but
yaegi v0.16.1's stdlib (used by traefik:v3.7.1) exposes only the
legacy unsafe.Pointer-based atomic primitives — no generic Pointer[T].
atomic.Value provides the same semantics via interface{} plus a type
assertion.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:31:51 +01:00

357 lines
11 KiB
Go

// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file contains utility/helper methods extracted from main.go for better code organization.
package traefikoidc
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"html"
"net/http"
"runtime"
"strings"
"time"
)
// metadataSnap loads the value currently held in t.metadataSnapshot and
// returns it as a *MetadataSnapshot. The result is nil when nothing has been
// stored yet, or when the stored value is not a *MetadataSnapshot.
//
// The read is a single Load call and takes no mutex; it is intended for
// hot-path readers as an alternative to acquiring metadataMu.RLock.
func (t *TraefikOidc) metadataSnap() *MetadataSnapshot {
	// A comma-ok assertion on a nil interface yields (nil, false), so the
	// "nothing published yet" case needs no separate branch.
	snap, _ := t.metadataSnapshot.Load().(*MetadataSnapshot)
	return snap
}
// safeLogDebug writes msg at debug level and is a no-op when t.logger is nil.
func (t *TraefikOidc) safeLogDebug(msg string) {
	if t.logger == nil {
		return
	}
	// msg is supplied as the argument to a "%s" format rather than as the
	// format itself.
	t.logger.Debug("%s", msg)
}
// safeLogDebugf forwards format and args to the logger's Debugf and is a
// no-op when t.logger is nil.
func (t *TraefikOidc) safeLogDebugf(format string, args ...interface{}) {
	if t.logger == nil {
		return
	}
	t.logger.Debugf(format, args...)
}
// safeLogError writes msg at error level and is a no-op when t.logger is nil.
func (t *TraefikOidc) safeLogError(msg string) {
	if t.logger == nil {
		return
	}
	// msg is supplied as the argument to a "%s" format rather than as the
	// format itself.
	t.logger.Error("%s", msg)
}
// safeLogErrorf forwards format and args to the logger's Errorf and is a
// no-op when t.logger is nil.
func (t *TraefikOidc) safeLogErrorf(format string, args ...interface{}) {
	if t.logger == nil {
		return
	}
	t.logger.Errorf(format, args...)
}
// safeLogInfo writes msg at info level and is a no-op when t.logger is nil.
func (t *TraefikOidc) safeLogInfo(msg string) {
	if t.logger == nil {
		return
	}
	// msg is supplied as the argument to a "%s" format rather than as the
	// format itself.
	t.logger.Info("%s", msg)
}
// isAllowedUser reports whether userIdentifier is authorized.
//
// Decision order:
//  1. If neither allowedUsers nor allowedUserDomains is configured, every
//     user is allowed.
//  2. If the lower-cased identifier is a key of allowedUsers, the user is
//     allowed.
//  3. If userIdentifierClaim is "email" and the identifier contains "@", the
//     decision is delegated to isAllowedDomain.
//  4. Otherwise the user is rejected. Domain validation is not applied to
//     non-email identifiers (sub, oid, upn, etc.).
//
// All logging goes through the nil-safe helpers so that an instance without a
// logger does not fail inside an authorization check.
//
// Parameters:
//   - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.).
//
// Returns:
//   - true if the user is authorized, false otherwise.
func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool {
	// If no restrictions are configured, allow all authenticated users.
	if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
		return true
	}

	// Explicit allow-list lookup. Looking up a key in an empty or nil map is
	// safe and simply reports "not present".
	if _, ok := t.allowedUsers[strings.ToLower(userIdentifier)]; ok {
		t.safeLogDebugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier)
		return true
	}

	// Domain restrictions only make sense when the identifier claim is the
	// email AND the value actually looks like an email address.
	if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") {
		return t.isAllowedDomain(userIdentifier)
	}

	// Domains are configured but cannot be applied to a non-email identifier.
	if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" {
		t.safeLogDebugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s",
			t.userIdentifierClaim, userIdentifier)
	}

	// The identifier was not found in a configured allow-list.
	if len(t.allowedUsers) > 0 {
		t.safeLogDebugf("User identifier %s is not in the allowed users list", userIdentifier)
	}
	return false
}
// isAllowedDomain reports whether email is authorized, either because the
// lower-cased address is a key of allowedUsers or because its domain is a key
// of allowedUserDomains.
//
// The domain is looked up exactly as written first and then lower-cased, so
// that an address such as "user@Example.COM" matches an allowed domain of
// "example.com" (domain names are case-insensitive). The exact lookup is kept
// first so that any entry that matched before still matches.
//
// All logging goes through the nil-safe helpers so that an instance without a
// logger does not fail inside an authorization check.
//
// Parameters:
//   - email: The email address to validate.
//
// Returns:
//   - true if the email is authorized (domain or user allowed), or if no
//     restrictions are configured; false if it is not authorized or does not
//     contain exactly one "@".
func (t *TraefikOidc) isAllowedDomain(email string) bool {
	if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
		return true
	}

	// Explicit allow-list lookup. Looking up a key in an empty or nil map is
	// safe and simply reports "not present".
	if _, ok := t.allowedUsers[strings.ToLower(email)]; ok {
		t.safeLogDebugf("Email %s is explicitly allowed in allowedUsers", email)
		return true
	}

	// Only a user allow-list is configured (the early return above rules out
	// both being empty) and the email was not in it.
	if len(t.allowedUserDomains) == 0 {
		t.safeLogDebugf("Email %s is not in the allowed users list: %v",
			email, keysFromMap(t.allowedUsers))
		return false
	}

	// Require exactly one "@": anything else is treated as malformed.
	parts := strings.Split(email, "@")
	if len(parts) != 2 {
		t.safeLogErrorf("Invalid email format encountered: %s", email)
		return false
	}
	domain := parts[1]

	_, domainAllowed := t.allowedUserDomains[domain]
	if !domainAllowed {
		// Case-insensitive fallback for the domain part.
		_, domainAllowed = t.allowedUserDomains[strings.ToLower(domain)]
	}
	if domainAllowed {
		t.safeLogDebugf("Email domain %s is allowed", domain)
		return true
	}

	t.safeLogDebugf("Email domain %s is NOT allowed. Allowed domains: %v",
		domain, keysFromMap(t.allowedUserDomains))
	return false
}
// keysFromMap collects the keys of a string set into a slice, primarily so
// they can be printed in log messages.
//
// Parameters:
//   - m: The map whose keys are collected.
//
// Returns:
//   - A non-nil slice holding every key of m. The order follows Go's map
//     iteration order and is therefore unspecified.
func keysFromMap(m map[string]struct{}) []string {
	out := make([]string, len(m))
	next := 0
	for key := range m {
		out[next] = key
		next++
	}
	return out
}
// sendErrorResponse writes an error response whose format depends on the
// request's Accept header: a JSON object when the header contains
// "application/json", otherwise a small HTML error page.
//
// The JSON body carries "error" (the standard status text for code),
// "error_description" (message) and "status_code" (code). The HTML page
// embeds the HTML-escaped message and a link back to "/".
//
// Logging goes through the nil-safe helper so that a missing logger cannot
// make the error path itself fail.
//
// Parameters:
//   - rw: The HTTP response writer.
//   - req: The HTTP request (used to check Accept header).
//   - message: The error message to display.
//   - code: The HTTP status code to set for the response.
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
	acceptHeader := req.Header.Get("Accept")
	if strings.Contains(acceptHeader, "application/json") {
		t.safeLogDebugf("Sending JSON error response (code %d): %s", code, message)
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(code)
		_ = json.NewEncoder(rw).Encode(map[string]interface{}{
			"error":             http.StatusText(code),
			"error_description": message,
			"status_code":       code,
		}) // Safe to ignore: error response write
		return
	}

	t.safeLogDebugf("Sending HTML error response (code %d): %s", code, message)
	returnURL := "/"
	// Escape message to prevent XSS attacks
	escapedMessage := html.EscapeString(message)
	htmlBody := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<title>Authentication Error</title>
<style>
body { font-family: sans-serif; padding: 20px; background-color: #f8f9fa; color: #343a40; }
h1 { color: #dc3545; }
a { color: #007bff; text-decoration: none; }
a:hover { text-decoration: underline; }
.container { max-width: 600px; margin: auto; background: #fff; padding: 20px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
</style>
</head>
<body>
<div class="container">
<h1>Authentication Error</h1>
<p>%s</p>
<p><a href="%s">Return to application</a></p>
</div>
</body>
</html>`, escapedMessage, returnURL)
	rw.Header().Set("Content-Type", "text/html; charset=utf-8")
	rw.WriteHeader(code)
	_, _ = rw.Write([]byte(htmlBody)) // Safe to ignore: error response write
}
// Close shuts down the TraefikOidc middleware instance.
//
// The shutdown body runs at most once (guarded by shutdownOnce). In order, it:
// stops the named background tasks via the resource manager, removes this
// instance's reference, cancels the instance context, closes the legacy stop
// channels, shuts down the refresh coordinator, waits up to 10 seconds for
// background goroutines, closes idle HTTP connections, closes the caches and
// token blacklist, shuts down the session manager, closes graceful
// degradation, stops all global tasks and finally forces a garbage
// collection.
//
// Returns:
//   - An error if waiting for background goroutines timed out and/or the
//     session manager failed to shut down (both are combined when both
//     occur). Every cleanup step is still attempted. Calls after the first
//     do nothing and return nil.
func (t *TraefikOidc) Close() error {
	var closeErr error

	// recordErr accumulates failures so Close can report them without
	// aborting the remaining best-effort cleanup steps.
	recordErr := func(err error) {
		if closeErr == nil {
			closeErr = err
			return
		}
		closeErr = fmt.Errorf("%w; %v", closeErr, err)
	}

	t.shutdownOnce.Do(func() {
		t.safeLogDebug("Closing TraefikOidc plugin instance")

		// Get resource manager for cleanup
		rm := GetResourceManager()

		// Stop singleton tasks related to this instance
		_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup

		// Stop metadata refresh task using same hash-based name as startMetadataRefresh
		if t.providerURL != "" {
			hash := sha256.Sum256([]byte(t.providerURL))
			taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6]
			_ = rm.StopBackgroundTask(taskName) // Safe to ignore: best effort cleanup
		}

		// Remove reference for this instance
		rm.RemoveReference(t.name)

		if t.cancelFunc != nil {
			t.cancelFunc()
			t.safeLogDebug("Context cancellation signaled to all goroutines")
		}

		// Clean up legacy stop channels if they exist
		if t.tokenCleanupStopChan != nil {
			close(t.tokenCleanupStopChan)
			t.safeLogDebug("tokenCleanupStopChan closed")
		}
		if t.metadataRefreshStopChan != nil {
			close(t.metadataRefreshStopChan)
			t.safeLogDebug("metadataRefreshStopChan closed")
		}

		if t.refreshCoordinator != nil {
			t.refreshCoordinator.Shutdown()
			t.safeLogDebug("refreshCoordinator shut down")
		}

		if t.goroutineWG != nil {
			// Wait in a helper goroutine so the wait can be bounded by a timeout.
			done := make(chan struct{})
			go func() {
				t.goroutineWG.Wait()
				close(done)
			}()
			select {
			case <-done:
				t.safeLogDebug("All background goroutines stopped gracefully")
			case <-time.After(10 * time.Second):
				t.safeLogError("Timeout waiting for background goroutines to stop")
				recordErr(fmt.Errorf("timeout waiting for background goroutines to stop"))
			}
		} else {
			t.safeLogDebug("No goroutineWG to wait for (likely in test)")
		}

		if t.httpClient != nil {
			if transport, ok := t.httpClient.Transport.(*http.Transport); ok {
				transport.CloseIdleConnections()
				t.safeLogDebug("HTTP client idle connections closed")
			}
		}

		if t.tokenHTTPClient != nil {
			if transport, ok := t.tokenHTTPClient.Transport.(*http.Transport); ok {
				// Closing once is sufficient; the second message only records
				// that the token client does not share the main transport.
				transport.CloseIdleConnections()
				t.safeLogDebug("Token HTTP client idle connections closed")
				// httpClient may be nil even when tokenHTTPClient is set, so
				// guard before reading its Transport.
				if t.httpClient == nil || t.tokenHTTPClient.Transport != t.httpClient.Transport {
					t.safeLogDebug("Token HTTP client transport closed (separate from main)")
				}
			}
		}

		if t.tokenBlacklist != nil {
			t.tokenBlacklist.Close()
			t.safeLogDebug("tokenBlacklist closed")
		}
		if t.metadataCache != nil {
			t.metadataCache.Close()
			t.safeLogDebug("metadataCache closed")
		}
		if t.tokenCache != nil {
			t.tokenCache.Close()
			t.safeLogDebug("tokenCache closed")
		}
		if t.jwkCache != nil {
			t.jwkCache.Close()
			t.safeLogDebug("t.jwkCache.Close() called as per original instruction.")
		}

		// Shutdown session manager and its background cleanup routines
		if t.sessionManager != nil {
			if err := t.sessionManager.Shutdown(); err != nil {
				t.safeLogErrorf("Error shutting down session manager: %v", err)
				recordErr(fmt.Errorf("shutting down session manager: %w", err))
			} else {
				t.safeLogDebug("sessionManager shutdown completed")
			}
		}

		// Clean up error recovery manager
		if t.errorRecoveryManager != nil && t.errorRecoveryManager.gracefulDegradation != nil {
			t.errorRecoveryManager.gracefulDegradation.Close()
			t.safeLogDebug("Error recovery manager graceful degradation closed")
		}

		// Stop all global background tasks
		taskRegistry := GetGlobalTaskRegistry()
		taskRegistry.StopAllTasks()
		t.safeLogDebug("All global background tasks stopped")

		// Note: Centralized pool in internal/pool is singleton-managed and doesn't require explicit cleanup
		t.safeLogDebug("Memory pools managed by singleton pattern")

		// Force garbage collection to help with memory cleanup after shutdown
		runtime.GC()
		t.safeLogDebug("Forced garbage collection after shutdown")

		t.safeLogDebug("TraefikOidc plugin instance closed successfully.")
	})
	return closeErr
}