mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
17e3f8ef62
Two related lock-free snapshot refactors addressing the remaining
post-v1.0.16 code-review findings.
1. refreshAttemptTracker: per-field atomic.Load/Store -> atomic.Value
snapshot of *attemptState (refresh_coordinator.go).
Previously each tracker held five independently-atomic fields. The
cooldown-exit reset wrote cooldownEndNano = 0 first, then separately
stored attempts = 1 and windowStartNano = now. A concurrent
isInCooldown call could observe cooldownEndNano = 0 (reset just
completed) with attempts still at MaxRefreshAttempts, immediately
triggering a fresh cooldown — a benign double-trigger race that
nonetheless meant the state machine had observable intermediate
states.
New design: state is a *attemptState (immutable) published via
atomic.Value. All transitions (record/success/failure/window-reset/
cooldown-enter/cooldown-exit) go through mutateState, which runs a
CAS loop: load current snapshot -> construct fresh snapshot ->
CompareAndSwap. Either the entire new state publishes or none of
it does — no intermediate visibility, no cross-field race.
Under Yaegi this collapses 3-5 per-field atomic dispatches into one
atomic.Value.Load on the read path. Write paths pay an extra
allocation for the new snapshot but avoid the cross-field hazard.
2. MetadataSnapshot: hot-path readers use atomic.Value instead of
metadataMu.RLock (middleware.go, types.go, main.go, utilities.go).
middleware.ServeHTTP previously took metadataMu.RLock on every
non-bypass request to read the single field issuerURL. Under Yaegi
each RLock acquisition costs 1-5ms of interpreter dispatch.
updateMetadataEndpoints now also publishes an immutable
*MetadataSnapshot via atomic.Value; the hot-path reader loads it
in one op via t.metadataSnap(). Falls back to the legacy
metadataMu.RLock pattern when the snapshot is unpublished (some
test setups initialize the struct fields directly without going
through updateMetadataEndpoints).
Less-frequent callers (helpers, logout, token_introspection) still
take metadataMu.RLock and are unchanged. The snapshot holds a strict
subset of the metadataMu-protected fields, so those readers continue
to see the same data.
Note on atomic.Pointer[T]: this would have been the cleaner type but
yaegi v0.16.1's stdlib (used by traefik:v3.7.1) exposes only the
legacy unsafe.Pointer-based atomic primitives — no generic Pointer[T].
atomic.Value provides the same semantics via interface{} + type assert.
All tests pass with -race; golangci-lint clean.
357 lines
11 KiB
Go
357 lines
11 KiB
Go
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
|
// This file contains utility/helper methods extracted from main.go for better code organization.
|
|
package traefikoidc
|
|
|
|
import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"html"
	"net/http"
	"runtime"
	"sort"
	"strings"
	"time"
)
|
|
|
|
// metadataSnap returns the most recently published *MetadataSnapshot, or nil
|
|
// if metadata has not yet been resolved. Single atomic.Value.Load — the hot
|
|
// ServeHTTP path uses this instead of acquiring metadataMu.RLock, which under
|
|
// Yaegi pays 1-5ms of interpreter-dispatch overhead per acquisition.
|
|
func (t *TraefikOidc) metadataSnap() *MetadataSnapshot {
|
|
v := t.metadataSnapshot.Load()
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
s, _ := v.(*MetadataSnapshot)
|
|
return s
|
|
}
|
|
|
|
// safeLogDebug provides nil-safe logging for debug messages
|
|
func (t *TraefikOidc) safeLogDebug(msg string) {
|
|
if t.logger != nil {
|
|
t.logger.Debug("%s", msg)
|
|
}
|
|
}
|
|
|
|
// safeLogDebugf provides nil-safe logging for formatted debug messages
|
|
func (t *TraefikOidc) safeLogDebugf(format string, args ...interface{}) {
|
|
if t.logger != nil {
|
|
t.logger.Debugf(format, args...)
|
|
}
|
|
}
|
|
|
|
// safeLogError provides nil-safe logging for error messages
|
|
func (t *TraefikOidc) safeLogError(msg string) {
|
|
if t.logger != nil {
|
|
t.logger.Error("%s", msg)
|
|
}
|
|
}
|
|
|
|
// safeLogErrorf provides nil-safe logging for formatted error messages
|
|
func (t *TraefikOidc) safeLogErrorf(format string, args ...interface{}) {
|
|
if t.logger != nil {
|
|
t.logger.Errorf(format, args...)
|
|
}
|
|
}
|
|
|
|
// safeLogInfo provides nil-safe logging for info messages
|
|
func (t *TraefikOidc) safeLogInfo(msg string) {
|
|
if t.logger != nil {
|
|
t.logger.Info("%s", msg)
|
|
}
|
|
}
|
|
|
|
// isAllowedUser checks if a user identifier is authorized based on the configured user identifier claim.
|
|
// When using email as the identifier (default), it validates against allowedUsers and allowedUserDomains.
|
|
// When using non-email identifiers (sub, oid, upn, etc.), it only validates against allowedUsers
|
|
// since domain-based validation doesn't apply to non-email identifiers.
|
|
//
|
|
// Parameters:
|
|
// - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.).
|
|
//
|
|
// Returns:
|
|
// - true if the user is authorized, false otherwise.
|
|
func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool {
|
|
// If no restrictions are configured, allow all authenticated users
|
|
if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
|
|
return true
|
|
}
|
|
|
|
// Check if user is explicitly allowed
|
|
if len(t.allowedUsers) > 0 {
|
|
_, userAllowed := t.allowedUsers[strings.ToLower(userIdentifier)]
|
|
if userAllowed {
|
|
t.logger.Debugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier)
|
|
return true
|
|
}
|
|
}
|
|
|
|
// For email-based identifiers, also check domain restrictions
|
|
// Only apply domain validation if using email as identifier AND identifier looks like an email
|
|
if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") {
|
|
return t.isAllowedDomain(userIdentifier)
|
|
}
|
|
|
|
// For non-email identifiers with allowedUserDomains configured, log a warning
|
|
if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" {
|
|
t.logger.Debugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s",
|
|
t.userIdentifierClaim, userIdentifier)
|
|
}
|
|
|
|
// User not found in allowedUsers list
|
|
if len(t.allowedUsers) > 0 {
|
|
t.logger.Debugf("User identifier %s is not in the allowed users list", userIdentifier)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isAllowedDomain checks if an email address is authorized based on domain or user whitelist.
|
|
// It validates against both allowed user domains and specific allowed users.
|
|
// Parameters:
|
|
// - email: The email address to validate.
|
|
//
|
|
// Returns:
|
|
// - true if the email is authorized (domain or user allowed), false if not authorized
|
|
// or if the email format is invalid.
|
|
func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
|
if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
|
|
return true
|
|
}
|
|
|
|
if len(t.allowedUsers) > 0 {
|
|
_, userAllowed := t.allowedUsers[strings.ToLower(email)]
|
|
if userAllowed {
|
|
t.logger.Debugf("Email %s is explicitly allowed in allowedUsers", email)
|
|
return true
|
|
}
|
|
}
|
|
|
|
if len(t.allowedUserDomains) > 0 {
|
|
parts := strings.Split(email, "@")
|
|
if len(parts) != 2 {
|
|
t.logger.Errorf("Invalid email format encountered: %s", email)
|
|
return false
|
|
}
|
|
|
|
domain := parts[1]
|
|
_, domainAllowed := t.allowedUserDomains[domain]
|
|
|
|
if domainAllowed {
|
|
t.logger.Debugf("Email domain %s is allowed", domain)
|
|
return true
|
|
} else {
|
|
t.logger.Debugf("Email domain %s is NOT allowed. Allowed domains: %v",
|
|
domain, keysFromMap(t.allowedUserDomains))
|
|
}
|
|
} else if len(t.allowedUsers) > 0 {
|
|
t.logger.Debugf("Email %s is not in the allowed users list: %v",
|
|
email, keysFromMap(t.allowedUsers))
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// keysFromMap returns the keys of m in ascending lexical order.
//
// It is used to render allow-lists in log messages. Go randomizes map
// iteration order, so the keys are sorted to keep that output stable from
// one call to the next.
//
// Parameters:
//   - m: The set whose keys are wanted. A nil or empty map yields an empty,
//     non-nil slice.
//
// Returns:
//   - A sorted slice of the map's keys.
func keysFromMap(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
|
|
|
|
// sendErrorResponse sends an appropriate error response based on the request's Accept header.
|
|
// It sends JSON responses for clients that accept JSON, otherwise sends HTML error pages.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The HTTP request (used to check Accept header).
|
|
// - message: The error message to display.
|
|
// - code: The HTTP status code to set for the response.
|
|
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
|
acceptHeader := req.Header.Get("Accept")
|
|
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
t.logger.Debugf("Sending JSON error response (code %d): %s", code, message)
|
|
rw.Header().Set("Content-Type", "application/json")
|
|
rw.WriteHeader(code)
|
|
_ = json.NewEncoder(rw).Encode(map[string]interface{}{
|
|
"error": http.StatusText(code),
|
|
"error_description": message,
|
|
"status_code": code,
|
|
}) // Safe to ignore: error response write
|
|
return
|
|
}
|
|
|
|
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
|
|
|
|
returnURL := "/"
|
|
// Escape message to prevent XSS attacks
|
|
escapedMessage := html.EscapeString(message)
|
|
|
|
htmlBody := fmt.Sprintf(`
|
|
<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<title>Authentication Error</title>
|
|
<style>
|
|
body { font-family: sans-serif; padding: 20px; background-color: #f8f9fa; color: #343a40; }
|
|
h1 { color: #dc3545; }
|
|
a { color: #007bff; text-decoration: none; }
|
|
a:hover { text-decoration: underline; }
|
|
.container { max-width: 600px; margin: auto; background: #fff; padding: 20px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
|
|
</style>
|
|
</head>
|
|
<body>
|
|
<div class="container">
|
|
<h1>Authentication Error</h1>
|
|
<p>%s</p>
|
|
<p><a href="%s">Return to application</a></p>
|
|
</div>
|
|
</body>
|
|
</html>`, escapedMessage, returnURL)
|
|
|
|
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
rw.WriteHeader(code)
|
|
_, _ = rw.Write([]byte(htmlBody)) // Safe to ignore: error response write
|
|
}
|
|
|
|
// Close gracefully shuts down the TraefikOidc middleware instance.
|
|
// It cancels contexts, stops background goroutines, closes HTTP connections,
|
|
// cleans up caches, and releases all resources. Safe to call multiple times.
|
|
// Returns:
|
|
// - An error if shutdown times out or resource cleanup fails.
|
|
func (t *TraefikOidc) Close() error {
|
|
var closeErr error
|
|
t.shutdownOnce.Do(func() {
|
|
t.safeLogDebug("Closing TraefikOidc plugin instance")
|
|
|
|
// Get resource manager for cleanup
|
|
rm := GetResourceManager()
|
|
|
|
// Stop singleton tasks related to this instance
|
|
_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup
|
|
// Stop metadata refresh task using same hash-based name as startMetadataRefresh
|
|
if t.providerURL != "" {
|
|
hash := sha256.Sum256([]byte(t.providerURL))
|
|
taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6]
|
|
_ = rm.StopBackgroundTask(taskName) // Safe to ignore: best effort cleanup
|
|
}
|
|
|
|
// Remove reference for this instance
|
|
rm.RemoveReference(t.name)
|
|
|
|
if t.cancelFunc != nil {
|
|
t.cancelFunc()
|
|
t.safeLogDebug("Context cancellation signaled to all goroutines")
|
|
}
|
|
|
|
// Clean up legacy stop channels if they exist
|
|
if t.tokenCleanupStopChan != nil {
|
|
close(t.tokenCleanupStopChan)
|
|
t.safeLogDebug("tokenCleanupStopChan closed")
|
|
}
|
|
if t.metadataRefreshStopChan != nil {
|
|
close(t.metadataRefreshStopChan)
|
|
t.safeLogDebug("metadataRefreshStopChan closed")
|
|
}
|
|
|
|
if t.refreshCoordinator != nil {
|
|
t.refreshCoordinator.Shutdown()
|
|
t.safeLogDebug("refreshCoordinator shut down")
|
|
}
|
|
|
|
if t.goroutineWG != nil {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
t.goroutineWG.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
t.safeLogDebug("All background goroutines stopped gracefully")
|
|
case <-time.After(10 * time.Second):
|
|
t.safeLogError("Timeout waiting for background goroutines to stop")
|
|
}
|
|
} else {
|
|
t.safeLogDebug("No goroutineWG to wait for (likely in test)")
|
|
}
|
|
|
|
if t.httpClient != nil {
|
|
if transport, ok := t.httpClient.Transport.(*http.Transport); ok {
|
|
transport.CloseIdleConnections()
|
|
t.safeLogDebug("HTTP client idle connections closed")
|
|
}
|
|
}
|
|
|
|
if t.tokenHTTPClient != nil {
|
|
if transport, ok := t.tokenHTTPClient.Transport.(*http.Transport); ok {
|
|
transport.CloseIdleConnections()
|
|
t.safeLogDebug("Token HTTP client idle connections closed")
|
|
}
|
|
if t.tokenHTTPClient.Transport != t.httpClient.Transport {
|
|
if transport, ok := t.tokenHTTPClient.Transport.(*http.Transport); ok {
|
|
transport.CloseIdleConnections()
|
|
t.safeLogDebug("Token HTTP client transport closed (separate from main)")
|
|
}
|
|
}
|
|
}
|
|
|
|
if t.tokenBlacklist != nil {
|
|
t.tokenBlacklist.Close()
|
|
t.safeLogDebug("tokenBlacklist closed")
|
|
}
|
|
if t.metadataCache != nil {
|
|
t.metadataCache.Close()
|
|
t.safeLogDebug("metadataCache closed")
|
|
}
|
|
if t.tokenCache != nil {
|
|
t.tokenCache.Close()
|
|
t.safeLogDebug("tokenCache closed")
|
|
}
|
|
|
|
if t.jwkCache != nil {
|
|
t.jwkCache.Close()
|
|
t.safeLogDebug("t.jwkCache.Close() called as per original instruction.")
|
|
}
|
|
|
|
// Shutdown session manager and its background cleanup routines
|
|
if t.sessionManager != nil {
|
|
if err := t.sessionManager.Shutdown(); err != nil {
|
|
t.safeLogErrorf("Error shutting down session manager: %v", err)
|
|
} else {
|
|
t.safeLogDebug("sessionManager shutdown completed")
|
|
}
|
|
}
|
|
|
|
// Clean up error recovery manager
|
|
if t.errorRecoveryManager != nil && t.errorRecoveryManager.gracefulDegradation != nil {
|
|
t.errorRecoveryManager.gracefulDegradation.Close()
|
|
t.safeLogDebug("Error recovery manager graceful degradation closed")
|
|
}
|
|
|
|
// Stop all global background tasks
|
|
taskRegistry := GetGlobalTaskRegistry()
|
|
taskRegistry.StopAllTasks()
|
|
t.safeLogDebug("All global background tasks stopped")
|
|
|
|
// Note: Centralized pool in internal/pool is singleton-managed and doesn't require explicit cleanup
|
|
t.safeLogDebug("Memory pools managed by singleton pattern")
|
|
|
|
// Force garbage collection to help with memory cleanup after shutdown
|
|
runtime.GC()
|
|
t.safeLogDebug("Forced garbage collection after shutdown")
|
|
|
|
t.safeLogDebug("TraefikOidc plugin instance closed successfully.")
|
|
})
|
|
return closeErr
|
|
}
|