mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
1b49e133da
* Fix bug affecting Azure OIDC authentication ( and most likely others ) * Fixes issue #51 * Ensure that appended roles are unique. Update the documentation. * Improvements targetting possible memory usage spikes. * Additional fixes and cleanup * Refactoring code to fix the issues identified by the users. * Modernize run * Fieldalignment * Multiple changes to improve performance and reduce complexity. - Optimise the errors and recovery. - Deduplicate code in metadata cache. - Remove unused performance monitoring code. - Simplify session management and settings handling. * Fix claims issue. * Add ability to overwrite the default scopes in the settings file * Well.. that escalated quickly. Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 ( pre-generic Go code ). * Bugfix #51: Ensures that user provided scopes overrides work. * fixup! Bugfix #51: Ensures that user provided scopes overrides work. * fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work. * Abstract the provider logic into a separate package. * Additional micro fixes and cleanups. * Simplify all the things. * fixup! Simplify all the things. * fixup! fixup! Simplify all the things. * fixup! fixup! fixup! Simplify all the things. * fixup! fixup! fixup! fixup! Simplify all the things. * ... * Cleanup tests. * fixup! Cleanup tests. * fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! fixup! Cleanup tests. * Issue #53: Fix CSRF token handling in reverse proxy 1. ✅ HTTPS Detection Fixed (session.go:723) - Now uses X-Forwarded-Proto header instead of r.URL.Scheme - Properly detects HTTPS in reverse proxy environments 2. ✅ SameSite Cookie Attribute Fixed - Removed automatic SameSiteStrictMode for HTTPS (would break OAuth) - Keeps SameSiteLaxMode to allow OAuth callbacks from external domains - Only uses Strict for AJAX requests which don't involve OAuth redirects 3. ✅ Cookie Domain Handling Fixed - Now respects X-Forwarded-Host header for cookie domain - Ensures cookies are set for the public domain, not internal proxy domain 4. ✅ EnhanceSessionSecurity Properly Integrated - Function is now actually called during session save - Applies security enhancements without breaking OAuth flow Why Issue #53 Failed Before: 1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back) 2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail 3. Cookie domain might have been wrong (internal vs public domain) Why It Works Now: 1. Cookies are properly marked Secure for HTTPS 2. Uses SameSite=Lax to allow OAuth provider callbacks 3. Cookie domain uses public domain from X-Forwarded-Host 4. CSRF token persists through the entire OAuth flow * Next set of enhancements together with memory usage improvements. * Memory leak fixes and optimisations. * CSRF and Cookie Domain fixes * fixup! CSRF and Cookie Domain fixes * Metadata cache leak fix + profiling * fixup! Metadata cache leak fix + profiling * Memory leaks hunting, part 1337. * Further pursue of perfection. * fixup! Further pursue of perfection. * fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * Clear race conditions * fixup! Clear race conditions * Weekend fun with memory leaks * Splitting code into multiple files with reasonable testing coverage. ``` ok github.com/lukaszraczylo/traefikoidc 117.017s coverage: 72.6% of statements ok github.com/lukaszraczylo/traefikoidc/auth 0.505s coverage: 87.1% of statements ok github.com/lukaszraczylo/traefikoidc/circuit_breaker 0.283s coverage: 99.0% of statements github.com/lukaszraczylo/traefikoidc/config coverage: 0.0% of statements ok github.com/lukaszraczylo/traefikoidc/handlers 0.349s coverage: 98.2% of statements ok github.com/lukaszraczylo/traefikoidc/internal/providers (cached) coverage: 94.3% of statements ok github.com/lukaszraczylo/traefikoidc/middleware 0.808s coverage: 78.0% of statements ok github.com/lukaszraczylo/traefikoidc/recovery 0.653s coverage: 100.0% of statements ok github.com/lukaszraczylo/traefikoidc/session/chunking (cached) coverage: 87.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/core (cached) coverage: 85.6% of statements ok github.com/lukaszraczylo/traefikoidc/session/crypto (cached) coverage: 81.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/storage (cached) coverage: 93.5% of statements ok github.com/lukaszraczylo/traefikoidc/session/validators (cached) coverage: 98.8% of statements ```` * fixup! Splitting code into multiple files with reasonable testing coverage. * fixup! fixup! Splitting code into multiple files with reasonable testing coverage. * Weekend fun with further optimisations. * fixup! Weekend fun with further optimisations. * fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * Pre-release cleanup. * Enhance test coverage. * fixup! Enhance test coverage. * fixup! fixup! Enhance test coverage. * fixup! fixup! fixup! Enhance test coverage.
2572 lines
86 KiB
Go
2572 lines
86 KiB
Go
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
|
// It supports multiple OIDC providers including Google, Azure AD, and generic OIDC providers
|
|
// with features like token refresh, session management, and provider-specific optimizations.
|
|
package traefikoidc
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"text/template"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// Deprecated: Use CreateDefaultHTTPClient from http_client_factory.go instead
|
|
// createDefaultHTTPClient is kept for backward compatibility
|
|
func createDefaultHTTPClient() *http.Client {
|
|
return CreateDefaultHTTPClient()
|
|
}
|
|
|
|
const (
|
|
ConstSessionTimeout = 86400
|
|
)
|
|
|
|
// isTestMode detects if the code is running in a test environment.
|
|
// It checks various indicators including environment variables, command-line arguments,
|
|
// and runtime compiler information to determine test context.
|
|
// This helps suppress diagnostic logs during testing to keep test output clean.
|
|
// Returns:
|
|
// - true if running in test mode, false otherwise.
|
|
func isTestMode() bool {
|
|
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
|
|
return true
|
|
}
|
|
|
|
if strings.Contains(os.Args[0], ".test") ||
|
|
strings.Contains(os.Args[0], "go_build_") ||
|
|
os.Getenv("GO_TEST") == "1" ||
|
|
runtime.Compiler == "yaegi" {
|
|
return true
|
|
}
|
|
|
|
for _, arg := range os.Args {
|
|
if strings.Contains(arg, "-test") {
|
|
return true
|
|
}
|
|
}
|
|
|
|
if runtime.Compiler == "gc" {
|
|
progName := os.Args[0]
|
|
if strings.Contains(progName, "test") ||
|
|
strings.HasSuffix(progName, ".test") ||
|
|
strings.Contains(progName, "__debug_bin") {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Only use runtime stack check as fallback when no explicit test conditions are being controlled
|
|
// This prevents interference with unit tests that want to test false conditions
|
|
// Skip runtime stack check if explicitly disabled for testing
|
|
if os.Getenv("DISABLE_RUNTIME_STACK_CHECK") != "1" &&
|
|
os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "" &&
|
|
os.Getenv("GO_TEST") == "" {
|
|
// Check runtime stack for test functions only as last resort
|
|
buf := make([]byte, 2048)
|
|
n := runtime.Stack(buf, false)
|
|
stack := string(buf[:n])
|
|
if strings.Contains(stack, "testing.tRunner") ||
|
|
strings.Contains(stack, "testing.(*T)") ||
|
|
strings.Contains(stack, ".test.") {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// defaultExcludedURLs are the paths that are excluded from authentication
|
|
var defaultExcludedURLs = map[string]struct{}{
|
|
"/favicon": {},
|
|
}
|
|
|
|
// VerifyToken verifies the validity of an ID token or access token.
|
|
// It performs comprehensive validation including format checks, blacklist verification,
|
|
// signature validation using JWKs, and standard claims validation. It also caches
|
|
// successfully verified tokens to avoid repeated verification.
|
|
// Parameters:
|
|
// - token: The JWT token string to verify.
|
|
//
|
|
// Returns:
|
|
// - An error if verification fails (e.g., blacklisted token, invalid format,
|
|
// signature failure, or claims error), nil if verification succeeds.
|
|
func (t *TraefikOidc) VerifyToken(token string) error {
|
|
if token == "" {
|
|
return fmt.Errorf("invalid JWT format: token is empty")
|
|
}
|
|
|
|
if strings.Count(token, ".") != 2 {
|
|
return fmt.Errorf("invalid JWT format: expected JWT with 3 parts, got %d parts", strings.Count(token, ".")+1)
|
|
}
|
|
|
|
if len(token) < 10 {
|
|
return fmt.Errorf("token too short to be valid JWT")
|
|
}
|
|
|
|
if t.tokenBlacklist != nil {
|
|
if blacklisted, exists := t.tokenBlacklist.Get(token); exists && blacklisted != nil {
|
|
return fmt.Errorf("token is blacklisted (raw string) in cache")
|
|
}
|
|
}
|
|
|
|
parsedJWT, parseErr := parseJWT(token)
|
|
if parseErr != nil {
|
|
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
|
|
}
|
|
|
|
tokenType := "UNKNOWN"
|
|
if aud, ok := parsedJWT.Claims["aud"]; ok {
|
|
if audStr, ok := aud.(string); ok && audStr == t.clientID {
|
|
tokenType = "ID_TOKEN"
|
|
}
|
|
}
|
|
if scope, ok := parsedJWT.Claims["scope"]; ok {
|
|
if _, ok := scope.(string); ok {
|
|
tokenType = "ACCESS_TOKEN"
|
|
}
|
|
}
|
|
|
|
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
|
if !strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
|
if t.tokenBlacklist != nil {
|
|
if blacklisted, exists := t.tokenBlacklist.Get(jti); exists && blacklisted != nil {
|
|
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
|
return nil
|
|
}
|
|
|
|
if !t.limiter.Allow() {
|
|
return fmt.Errorf("rate limit exceeded")
|
|
}
|
|
|
|
jwt := parsedJWT
|
|
|
|
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
|
if !strings.Contains(err.Error(), "token has expired") {
|
|
t.safeLogErrorf("%s token verification failed: %v", tokenType, err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
t.cacheVerifiedToken(token, jwt.Claims)
|
|
|
|
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
|
|
expiry := time.Now().Add(defaultBlacklistDuration)
|
|
if expClaim, expOk := jwt.Claims["exp"].(float64); expOk {
|
|
expTime := time.Unix(int64(expClaim), 0)
|
|
tokenDuration := time.Until(expTime)
|
|
if tokenDuration > defaultBlacklistDuration && tokenDuration < (24*time.Hour) {
|
|
expiry = expTime
|
|
} else if tokenDuration <= 0 {
|
|
expiry = time.Now().Add(defaultBlacklistDuration)
|
|
} else {
|
|
expiry = time.Now().Add(defaultBlacklistDuration)
|
|
}
|
|
}
|
|
|
|
if t.tokenBlacklist != nil {
|
|
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
|
|
t.safeLogDebugf("Added JTI %s to blacklist cache", jti)
|
|
} else {
|
|
t.safeLogErrorf("Token blacklist not available, skipping JTI %s blacklist", jti)
|
|
}
|
|
|
|
replayCacheMu.Lock()
|
|
if replayCache == nil {
|
|
initReplayCache()
|
|
}
|
|
duration := time.Until(expiry)
|
|
if duration > 0 {
|
|
replayCache.Set(jti, true, duration)
|
|
}
|
|
replayCacheMu.Unlock()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// cacheVerifiedToken stores a successfully verified token and its claims in the cache.
|
|
// The token is cached until its expiration time to avoid repeated verification.
|
|
// Parameters:
|
|
// - token: The verified token string to cache.
|
|
// - claims: The map of claims extracted from the verified token.
|
|
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
|
|
expClaim, ok := claims["exp"].(float64)
|
|
if !ok {
|
|
t.safeLogError("Failed to cache token: invalid 'exp' claim type")
|
|
return
|
|
}
|
|
|
|
expirationTime := time.Unix(int64(expClaim), 0)
|
|
now := time.Now()
|
|
duration := expirationTime.Sub(now)
|
|
t.tokenCache.Set(token, claims, duration)
|
|
}
|
|
|
|
// VerifyJWTSignatureAndClaims verifies JWT signature using provider's public keys and validates standard claims.
|
|
// It retrieves the appropriate public key from the JWKS cache, verifies the token signature,
|
|
// and validates standard OIDC claims like issuer, audience, and expiration.
|
|
// Parameters:
|
|
// - jwt: The parsed JWT structure containing header and claims.
|
|
// - token: The raw token string for signature verification.
|
|
//
|
|
// Returns:
|
|
// - An error if verification fails (e.g., JWKS retrieval failed, no matching key,
|
|
// signature verification failed, standard claim validation failed), nil if successful.
|
|
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
|
t.safeLogDebugf("Verifying JWT signature and claims")
|
|
|
|
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get JWKS: %w", err)
|
|
}
|
|
|
|
if !t.suppressDiagnosticLogs && jwks != nil {
|
|
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), t.jwksURL)
|
|
}
|
|
|
|
kid, ok := jwt.Header["kid"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing key ID in token header")
|
|
}
|
|
alg, ok := jwt.Header["alg"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing algorithm in token header")
|
|
}
|
|
|
|
if !t.suppressDiagnosticLogs {
|
|
t.safeLogDebugf("DIAGNOSTIC: Looking for kid=%s, alg=%s in JWKS", kid, alg)
|
|
}
|
|
|
|
if jwks == nil {
|
|
return fmt.Errorf("JWKS is nil, cannot verify token")
|
|
}
|
|
|
|
// Find the matching key in JWKS
|
|
var matchingKey *JWK
|
|
availableKids := make([]string, 0, len(jwks.Keys))
|
|
for _, key := range jwks.Keys {
|
|
availableKids = append(availableKids, key.Kid)
|
|
if key.Kid == kid {
|
|
matchingKey = &key
|
|
break
|
|
}
|
|
}
|
|
|
|
if matchingKey == nil {
|
|
if !t.suppressDiagnosticLogs {
|
|
t.safeLogErrorf("DIAGNOSTIC: No matching key found for kid=%s. Available kids: %v", kid, availableKids)
|
|
}
|
|
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
|
}
|
|
|
|
if !t.suppressDiagnosticLogs {
|
|
t.safeLogDebugf("DIAGNOSTIC: Found matching key for kid=%s, key type: %s", kid, matchingKey.Kty)
|
|
}
|
|
|
|
publicKeyPEM, err := jwkToPEM(matchingKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
|
}
|
|
|
|
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
|
|
if !t.suppressDiagnosticLogs {
|
|
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
|
}
|
|
return fmt.Errorf("signature verification failed: %w", err)
|
|
}
|
|
|
|
if !t.suppressDiagnosticLogs {
|
|
t.safeLogDebugf("DIAGNOSTIC: Signature verification successful for kid=%s", kid)
|
|
}
|
|
|
|
if err := jwt.Verify(t.issuerURL, t.clientID, true); err != nil {
|
|
return fmt.Errorf("standard claim verification failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// mergeScopes combines default scopes with user-provided scopes, removing duplicates.
|
|
// Default scopes are placed first, followed by user scopes not already present.
|
|
// Parameters:
|
|
// - defaultScopes: The default scopes required by the application.
|
|
// - userScopes: Additional scopes specified by the user.
|
|
//
|
|
// Returns:
|
|
// - A slice containing merged scopes with defaults first, then user scopes, with duplicates removed.
|
|
func mergeScopes(defaultScopes, userScopes []string) []string {
|
|
if len(userScopes) == 0 {
|
|
return append([]string(nil), defaultScopes...)
|
|
}
|
|
|
|
seen := make(map[string]bool)
|
|
var result []string
|
|
|
|
for _, scope := range defaultScopes {
|
|
if !seen[scope] {
|
|
seen[scope] = true
|
|
result = append(result, scope)
|
|
}
|
|
}
|
|
|
|
for _, scope := range userScopes {
|
|
if !seen[scope] {
|
|
seen[scope] = true
|
|
result = append(result, scope)
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// New creates a new TraefikOidc middleware instance.
|
|
// It initializes all components including caches, HTTP clients, session management,
|
|
// templates, and starts background processes for metadata discovery.
|
|
// Parameters:
|
|
// - ctx: The context for the middleware lifecycle.
|
|
// - next: The next HTTP handler in the middleware chain.
|
|
// - config: The OIDC configuration containing provider details, client credentials, etc.
|
|
// - name: The name of the middleware instance.
|
|
//
|
|
// Returns:
|
|
// - The configured TraefikOidc handler ready to process requests.
|
|
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
|
|
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
|
return NewWithContext(ctx, config, next, name)
|
|
}
|
|
|
|
// NewWithContext creates a new TraefikOidc middleware instance with proper context handling.
|
|
// This is the preferred constructor that ensures proper goroutine lifecycle management.
|
|
func NewWithContext(ctx context.Context, config *Config, next http.Handler, name string) (*TraefikOidc, error) {
|
|
if config == nil {
|
|
config = CreateConfig()
|
|
}
|
|
|
|
if config.SessionEncryptionKey == "" {
|
|
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
|
}
|
|
|
|
logger := NewLogger(config.LogLevel)
|
|
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
|
|
if runtime.Compiler == "yaegi" {
|
|
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
|
logger.Infof("Session encryption key is too short; using default key for analyzer")
|
|
} else {
|
|
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
|
}
|
|
}
|
|
// Setup HTTP client
|
|
var httpClient *http.Client
|
|
if config.HTTPClient != nil {
|
|
httpClient = config.HTTPClient
|
|
} else {
|
|
httpClient = CreateDefaultHTTPClient()
|
|
}
|
|
goroutineWG := &sync.WaitGroup{}
|
|
cacheManager := GetGlobalCacheManager(goroutineWG)
|
|
|
|
// Use provided context instead of creating new one
|
|
var pluginCtx context.Context
|
|
var cancelFunc context.CancelFunc
|
|
if ctx != nil {
|
|
pluginCtx, cancelFunc = context.WithCancel(ctx)
|
|
} else {
|
|
pluginCtx, cancelFunc = context.WithCancel(context.Background())
|
|
}
|
|
|
|
t := &TraefikOidc{
|
|
next: next,
|
|
name: name,
|
|
goroutineWG: goroutineWG,
|
|
redirURLPath: config.CallbackURL,
|
|
logoutURLPath: func() string {
|
|
if config.LogoutURL == "" {
|
|
return config.CallbackURL + "/logout"
|
|
}
|
|
return config.LogoutURL
|
|
}(),
|
|
postLogoutRedirectURI: func() string {
|
|
if config.PostLogoutRedirectURI == "" {
|
|
return "/"
|
|
}
|
|
return config.PostLogoutRedirectURI
|
|
}(),
|
|
tokenBlacklist: cacheManager.GetSharedTokenBlacklist(),
|
|
jwkCache: cacheManager.GetSharedJWKCache(),
|
|
metadataCache: cacheManager.GetSharedMetadataCache(),
|
|
clientID: config.ClientID,
|
|
clientSecret: config.ClientSecret,
|
|
forceHTTPS: config.ForceHTTPS,
|
|
enablePKCE: config.EnablePKCE,
|
|
overrideScopes: config.OverrideScopes,
|
|
scopes: func() []string {
|
|
userProvidedScopes := deduplicateScopes(config.Scopes)
|
|
|
|
if config.OverrideScopes {
|
|
return userProvidedScopes
|
|
}
|
|
|
|
defaultSystemScopes := []string{"openid", "profile", "email"}
|
|
return deduplicateScopes(mergeScopes(defaultSystemScopes, userProvidedScopes))
|
|
}(),
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
|
tokenCache: cacheManager.GetSharedTokenCache(),
|
|
httpClient: httpClient,
|
|
tokenHTTPClient: CreateTokenHTTPClient(),
|
|
excludedURLs: createStringMap(config.ExcludedURLs),
|
|
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
|
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
|
|
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
|
initComplete: make(chan struct{}),
|
|
logger: logger,
|
|
refreshGracePeriod: func() time.Duration {
|
|
if config.RefreshGracePeriodSeconds > 0 {
|
|
return time.Duration(config.RefreshGracePeriodSeconds) * time.Second
|
|
}
|
|
return 60 * time.Second
|
|
}(),
|
|
tokenCleanupStopChan: make(chan struct{}),
|
|
metadataRefreshStopChan: make(chan struct{}),
|
|
ctx: pluginCtx,
|
|
cancelFunc: cancelFunc,
|
|
suppressDiagnosticLogs: isTestMode(),
|
|
}
|
|
|
|
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, t.logger)
|
|
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
|
|
|
|
// Initialize token resilience manager with default configuration
|
|
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
|
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
|
|
|
t.extractClaimsFunc = extractClaims
|
|
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
}
|
|
|
|
for k, v := range defaultExcludedURLs {
|
|
t.excludedURLs[k] = v
|
|
}
|
|
|
|
t.tokenVerifier = t
|
|
t.jwtVerifier = t
|
|
t.tokenExchanger = t
|
|
|
|
t.headerTemplates = make(map[string]*template.Template)
|
|
|
|
funcMap := template.FuncMap{
|
|
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
|
if val == nil || val == "" {
|
|
return defaultVal
|
|
}
|
|
return val
|
|
},
|
|
"get": func(m interface{}, key string) interface{} {
|
|
if mapVal, ok := m.(map[string]interface{}); ok {
|
|
if val, exists := mapVal[key]; exists {
|
|
return val
|
|
}
|
|
}
|
|
return ""
|
|
},
|
|
}
|
|
|
|
for _, header := range config.Headers {
|
|
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
|
|
|
|
parsedTmpl, err := tmpl.Parse(header.Value)
|
|
if err != nil {
|
|
logger.Errorf("Failed to parse header template for %s: %v", header.Name, err)
|
|
continue
|
|
}
|
|
|
|
t.headerTemplates[header.Name] = parsedTmpl
|
|
logger.Debugf("Parsed template for header %s: %s", header.Name, header.Value)
|
|
}
|
|
|
|
startReplayCacheCleanup(pluginCtx, logger)
|
|
|
|
// Start memory monitoring for leak detection and performance insights
|
|
memoryMonitor := GetGlobalMemoryMonitor()
|
|
monitorInterval := 60 * time.Second
|
|
if isTestMode() {
|
|
monitorInterval = 100 * time.Millisecond // Fast interval for tests
|
|
}
|
|
memoryMonitor.StartMonitoring(pluginCtx, monitorInterval)
|
|
logger.Debug("Started global memory monitoring")
|
|
|
|
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
|
|
|
|
t.providerURL = config.ProviderURL
|
|
|
|
// Use singleton resource manager for metadata initialization
|
|
rm := GetResourceManager()
|
|
|
|
// Add reference for this instance
|
|
rm.AddReference(name)
|
|
|
|
// Initialize metadata in a goroutine with proper tracking
|
|
if t.goroutineWG != nil {
|
|
t.goroutineWG.Add(1)
|
|
}
|
|
go func() {
|
|
defer func() {
|
|
if t.goroutineWG != nil {
|
|
t.goroutineWG.Done()
|
|
}
|
|
// Recover from panics to prevent goroutine leaks
|
|
if r := recover(); r != nil {
|
|
t.safeLogErrorf("Initialize metadata goroutine panic recovered: %v", r)
|
|
}
|
|
}()
|
|
t.initializeMetadata(config.ProviderURL)
|
|
}()
|
|
|
|
// Setup cleanup hook for when context is cancelled
|
|
if pluginCtx != nil {
|
|
go func() {
|
|
<-pluginCtx.Done()
|
|
t.Close()
|
|
}()
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
// initializeMetadata initializes OIDC provider metadata by fetching configuration.
|
|
// It retrieves the provider's .well-known/openid-configuration and updates
|
|
// internal endpoint URLs. Uses error recovery if available for resilient fetching.
|
|
// Parameters:
|
|
// - providerURL: The base URL of the OIDC provider.
|
|
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
|
t.safeLogDebug("Starting provider metadata discovery")
|
|
|
|
// Ensure initComplete is always closed, even on failure
|
|
defer func() {
|
|
select {
|
|
case <-t.initComplete:
|
|
// Already closed, do nothing
|
|
default:
|
|
close(t.initComplete)
|
|
}
|
|
}()
|
|
|
|
// Get metadata from cache or fetch it with error recovery if available
|
|
var metadata *ProviderMetadata
|
|
var err error
|
|
if t.errorRecoveryManager != nil {
|
|
metadata, err = t.metadataCache.GetMetadataWithRecovery(providerURL, t.httpClient, t.logger, t.errorRecoveryManager)
|
|
} else {
|
|
metadata, err = t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
|
}
|
|
if err != nil {
|
|
t.safeLogErrorf("Failed to get provider metadata: %v", err)
|
|
return
|
|
}
|
|
|
|
if metadata != nil {
|
|
t.safeLogDebug("Successfully initialized provider metadata")
|
|
t.updateMetadataEndpoints(metadata)
|
|
return
|
|
}
|
|
|
|
t.safeLogError("Received nil metadata during initialization")
|
|
}
|
|
|
|
// updateMetadataEndpoints updates internal endpoint URLs with discovered metadata.
|
|
// It sets the authorization URL, token URL, JWKS URL, issuer URL, revocation URL,
|
|
// and end session URL based on the provider's metadata.
|
|
// Parameters:
|
|
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
|
|
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
|
t.jwksURL = metadata.JWKSURL
|
|
t.authURL = metadata.AuthURL
|
|
t.tokenURL = metadata.TokenURL
|
|
t.issuerURL = metadata.Issuer
|
|
t.revocationURL = metadata.RevokeURL
|
|
t.endSessionURL = metadata.EndSessionURL
|
|
}
|
|
|
|
// startMetadataRefresh starts a background goroutine that periodically refreshes provider metadata.
|
|
// It runs every 2 hours and implements exponential backoff for consecutive failures.
|
|
// The refresh helps ensure endpoint URLs stay current and handles provider configuration changes.
|
|
// Parameters:
|
|
// - providerURL: The base URL of the OIDC provider, used for subsequent refresh attempts.
|
|
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
|
// Use singleton resource manager for metadata refresh
|
|
rm := GetResourceManager()
|
|
taskName := "singleton-metadata-refresh"
|
|
|
|
// Create refresh function
|
|
refreshFunc := func() {
|
|
if t.metadataCache == nil || t.httpClient == nil {
|
|
return
|
|
}
|
|
|
|
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
|
if err != nil {
|
|
t.safeLogErrorf("Failed to refresh provider metadata: %v", err)
|
|
return
|
|
}
|
|
|
|
if metadata != nil {
|
|
t.updateMetadataEndpoints(metadata)
|
|
t.safeLogDebug("Successfully refreshed provider metadata")
|
|
}
|
|
}
|
|
|
|
// Register as singleton task - will return existing if already registered
|
|
err := rm.RegisterBackgroundTask(taskName, 2*time.Hour, refreshFunc)
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to register metadata refresh task: %v", err)
|
|
return
|
|
}
|
|
|
|
// Start the task if not already running
|
|
if !rm.IsTaskRunning(taskName) {
|
|
rm.StartBackgroundTask(taskName)
|
|
t.logger.Debug("Started singleton metadata refresh task")
|
|
} else {
|
|
t.logger.Debug("Metadata refresh task already running, skipping duplicate")
|
|
}
|
|
}
|
|
|
|
// ServeHTTP implements the main middleware logic for processing HTTP requests.
|
|
// It handles the complete OIDC authentication flow including:
|
|
// - Excluded URL bypass
|
|
// - Session validation and management
|
|
// - Authentication callback processing
|
|
// - Logout handling
|
|
// - Token verification and refresh
|
|
// - Header injection for authenticated requests
|
|
//
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The incoming HTTP request.
|
|
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
if !strings.HasPrefix(req.URL.Path, "/health") {
|
|
t.firstRequestMutex.Lock()
|
|
if !t.firstRequestReceived {
|
|
t.firstRequestReceived = true
|
|
t.logger.Debug("Starting background tasks on first request")
|
|
t.startTokenCleanup()
|
|
|
|
if !t.metadataRefreshStarted && t.providerURL != "" {
|
|
t.metadataRefreshStarted = true
|
|
// Metadata refresh is handled by singleton resource manager
|
|
t.startMetadataRefresh(t.providerURL)
|
|
}
|
|
}
|
|
t.firstRequestMutex.Unlock()
|
|
}
|
|
|
|
select {
|
|
case <-t.initComplete:
|
|
if t.issuerURL == "" {
|
|
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
|
|
t.sendErrorResponse(rw, req, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
case <-req.Context().Done():
|
|
t.logger.Debug("Request cancelled while waiting for OIDC initialization")
|
|
t.sendErrorResponse(rw, req, "Request cancelled", http.StatusRequestTimeout)
|
|
return
|
|
case <-time.After(30 * time.Second):
|
|
t.logger.Error("Timeout waiting for OIDC initialization")
|
|
t.sendErrorResponse(rw, req, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
if t.determineExcludedURL(req.URL.Path) {
|
|
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
|
t.next.ServeHTTP(rw, req)
|
|
return
|
|
}
|
|
acceptHeader := req.Header.Get("Accept")
|
|
if strings.Contains(acceptHeader, "text/event-stream") {
|
|
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
|
t.next.ServeHTTP(rw, req)
|
|
return
|
|
}
|
|
|
|
t.sessionManager.CleanupOldCookies(rw, req)
|
|
|
|
session, err := t.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
|
|
cleanReq := req.Clone(req.Context())
|
|
session, _ = t.sessionManager.GetSession(cleanReq)
|
|
if session != nil {
|
|
defer session.returnToPoolSafely()
|
|
if clearErr := session.Clear(cleanReq, rw); clearErr != nil {
|
|
t.logger.Errorf("Error clearing potentially corrupted session: %v", clearErr)
|
|
}
|
|
} else {
|
|
t.logger.Error("Critical session error: Failed to get even a new session.")
|
|
t.sendErrorResponse(rw, req, "Critical session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
scheme := t.determineScheme(req)
|
|
host := t.determineHost(req)
|
|
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
|
|
defer session.returnToPoolSafely()
|
|
|
|
scheme := t.determineScheme(req)
|
|
host := t.determineHost(req)
|
|
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
|
|
|
if req.URL.Path == t.logoutURLPath {
|
|
t.handleLogout(rw, req)
|
|
return
|
|
}
|
|
if req.URL.Path == t.redirURLPath {
|
|
t.handleCallback(rw, req, redirectURL)
|
|
return
|
|
}
|
|
|
|
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
|
|
|
if expired {
|
|
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
|
|
t.handleExpiredToken(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
|
|
email := session.GetEmail()
|
|
// Domain restriction check removed debug output
|
|
if authenticated && email != "" {
|
|
if !t.isAllowedDomain(email) {
|
|
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
|
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
|
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
|
|
if authenticated && !needsRefresh {
|
|
t.logger.Debug("User authenticated and token valid, proceeding to process authorized request")
|
|
if accessToken := session.GetAccessToken(); accessToken != "" {
|
|
if strings.Count(accessToken, ".") == 2 {
|
|
if err := t.verifyToken(accessToken); err != nil {
|
|
t.logger.Errorf("Access token validation failed: %v", err)
|
|
t.handleExpiredToken(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
} else {
|
|
t.logger.Debugf("Access token appears opaque, skipping JWT verification for it.")
|
|
}
|
|
}
|
|
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
|
|
refreshTokenPresent := session.GetRefreshToken() != ""
|
|
|
|
// Check if this is an AJAX request that should receive 401 instead of redirect
|
|
isAjaxRequest := t.isAjaxRequest(req)
|
|
|
|
// Check if refresh token is likely expired (older than 6 hours)
|
|
refreshTokenExpired := refreshTokenPresent && t.isRefreshTokenExpired(session)
|
|
|
|
shouldAttemptRefresh := needsRefresh && refreshTokenPresent && !refreshTokenExpired
|
|
|
|
// If AJAX request and refresh token expired, return 401 immediately
|
|
if isAjaxRequest && refreshTokenExpired {
|
|
t.logger.Debug("AJAX request with expired refresh token, returning 401")
|
|
t.sendErrorResponse(rw, req, "Session expired", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if shouldAttemptRefresh {
|
|
idToken := session.GetIDToken()
|
|
if idToken != "" {
|
|
jwt, err := parseJWT(idToken)
|
|
if err == nil {
|
|
claims := jwt.Claims
|
|
if expClaim, ok := claims["exp"].(float64); ok {
|
|
expTime := int64(expClaim)
|
|
expTimeObj := time.Unix(expTime, 0)
|
|
refreshThreshold := time.Now().Add(t.refreshGracePeriod)
|
|
|
|
if !expTimeObj.Before(refreshThreshold) {
|
|
t.logger.Debug("Token is valid and outside grace period, skipping refresh")
|
|
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
} else {
|
|
t.logger.Debug("Could not extract 'exp' claim for grace period check, proceeding with refresh")
|
|
}
|
|
}
|
|
}
|
|
|
|
if needsRefresh && authenticated {
|
|
t.logger.Debug("Session token needs proactive refresh, attempting refresh")
|
|
} else if needsRefresh && !authenticated {
|
|
t.logger.Debug("ID token invalid/expired, but refresh token found. Attempting refresh.")
|
|
}
|
|
|
|
refreshed := t.refreshToken(rw, req, session)
|
|
if refreshed {
|
|
email = session.GetEmail()
|
|
if email != "" && !t.isAllowedDomain(email) {
|
|
t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email)
|
|
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
|
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
t.logger.Debug("Token refresh successful, proceeding to process authorized request")
|
|
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
|
|
t.logger.Debug("Token refresh failed, requiring re-authentication")
|
|
if isAjaxRequest {
|
|
t.logger.Debug("AJAX request with failed token refresh, sending 401 Unauthorized")
|
|
t.sendErrorResponse(rw, req, "Token refresh failed", http.StatusUnauthorized)
|
|
} else {
|
|
t.logger.Debug("Browser request with failed token refresh, initiating re-auth")
|
|
// Reset redirect count when starting fresh auth after failed refresh to prevent redirect loops
|
|
session.ResetRedirectCount()
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
}
|
|
return
|
|
}
|
|
|
|
t.logger.Debugf("Initiating full OIDC authentication flow (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
|
|
|
|
// If AJAX request without valid authentication, return 401
|
|
if isAjaxRequest {
|
|
t.logger.Debug("AJAX request requires authentication, sending 401 Unauthorized")
|
|
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Reset redirect count when starting fresh authentication flow
|
|
session.ResetRedirectCount()
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
}
|
|
|
|
// processAuthorizedRequest processes requests for authenticated users.
|
|
// It extracts claims, validates roles/groups if configured, sets authentication headers,
|
|
// processes header templates, and forwards the request to the next handler.
|
|
// Domain checks should be performed before calling this method.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The HTTP request to process.
|
|
// - session: The user's session data containing tokens and claims.
|
|
// - redirectURL: The callback URL for re-authentication if needed.
|
|
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
|
email := session.GetEmail()
|
|
if email == "" {
|
|
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
|
// Reset redirect count to prevent loops when session is invalid
|
|
session.ResetRedirectCount()
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
|
|
tokenForClaims := session.GetIDToken()
|
|
if tokenForClaims == "" {
|
|
tokenForClaims = session.GetAccessToken()
|
|
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
|
|
t.logger.Error("No token available but roles/groups checks are required")
|
|
// Reset redirect count to prevent loops when token is missing
|
|
session.ResetRedirectCount()
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Initialize empty slices
|
|
var groups, roles []string
|
|
|
|
if tokenForClaims != "" {
|
|
var err error
|
|
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
|
|
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
|
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
|
// Reset redirect count to prevent loops when claim extraction fails
|
|
session.ResetRedirectCount()
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
return
|
|
} else if err == nil {
|
|
if len(groups) > 0 {
|
|
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
|
}
|
|
if len(roles) > 0 {
|
|
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(t.allowedRolesAndGroups) > 0 {
|
|
allowed := false
|
|
for _, roleOrGroup := range append(groups, roles...) {
|
|
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
|
allowed = true
|
|
break
|
|
}
|
|
}
|
|
if !allowed {
|
|
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
|
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
|
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
|
|
req.Header.Set("X-Forwarded-User", email)
|
|
|
|
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
|
req.Header.Set("X-Auth-Request-User", email)
|
|
if idToken := session.GetIDToken(); idToken != "" {
|
|
req.Header.Set("X-Auth-Request-Token", idToken)
|
|
}
|
|
|
|
if len(t.headerTemplates) > 0 {
|
|
claims, err := t.extractClaimsFunc(session.GetIDToken())
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
|
|
} else {
|
|
templateData := map[string]interface{}{
|
|
"AccessToken": session.GetAccessToken(),
|
|
"IDToken": session.GetIDToken(),
|
|
"RefreshToken": session.GetRefreshToken(),
|
|
"Claims": claims,
|
|
}
|
|
|
|
for headerName, tmpl := range t.headerTemplates {
|
|
var buf bytes.Buffer
|
|
|
|
if err := tmpl.Execute(&buf, templateData); err != nil {
|
|
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
|
continue
|
|
}
|
|
headerValue := buf.String()
|
|
|
|
req.Header.Set(headerName, headerValue)
|
|
|
|
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
|
}
|
|
session.MarkDirty()
|
|
t.logger.Debugf("Session marked dirty after templated header processing.")
|
|
}
|
|
}
|
|
|
|
if session.IsDirty() {
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
|
}
|
|
} else {
|
|
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
|
|
}
|
|
|
|
rw.Header().Set("X-Frame-Options", "DENY")
|
|
rw.Header().Set("X-Content-Type-Options", "nosniff")
|
|
rw.Header().Set("X-XSS-Protection", "1; mode=block")
|
|
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
|
|
|
origin := req.Header.Get("Origin")
|
|
if origin != "" {
|
|
rw.Header().Set("Access-Control-Allow-Origin", origin)
|
|
rw.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
|
rw.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
|
|
|
if req.Method == "OPTIONS" {
|
|
rw.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
}
|
|
|
|
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
|
|
|
t.next.ServeHTTP(rw, req)
|
|
}
|
|
|
|
// handleExpiredToken handles requests with expired or invalid tokens.
|
|
// It clears the session data and initiates a new authentication flow.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The HTTP request with expired token.
|
|
// - session: The session data to clear.
|
|
// - redirectURL: The callback URL to be used in the new authentication flow.
|
|
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
|
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
|
|
session.SetAuthenticated(false)
|
|
session.SetIDToken("")
|
|
session.SetAccessToken("")
|
|
session.SetRefreshToken("")
|
|
session.SetEmail("")
|
|
// Clear CSRF tokens to prevent replay attacks
|
|
session.SetCSRF("")
|
|
session.SetNonce("")
|
|
session.SetCodeVerifier("")
|
|
// Reset redirect count to prevent loops when handling expired tokens
|
|
session.ResetRedirectCount()
|
|
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
|
|
}
|
|
|
|
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
|
}
|
|
|
|
// handleCallback processes the OIDC callback after user authentication.
|
|
// It validates state/CSRF tokens, exchanges authorization code for tokens,
|
|
// verifies the received tokens, extracts claims, and establishes the session.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The callback request containing authorization code and state.
|
|
// - redirectURL: The fully qualified callback URL (used in the token exchange request).
|
|
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
|
session, err := t.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.logger.Errorf("Session error during callback: %v", err)
|
|
t.sendErrorResponse(rw, req, "Session error during callback", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer session.returnToPoolSafely()
|
|
|
|
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
|
|
|
if req.URL.Query().Get("error") != "" {
|
|
errorDescription := req.URL.Query().Get("error_description")
|
|
if errorDescription == "" {
|
|
errorDescription = req.URL.Query().Get("error")
|
|
}
|
|
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
|
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
state := req.URL.Query().Get("state")
|
|
if state == "" {
|
|
t.logger.Error("No state in callback")
|
|
t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
csrfToken := session.GetCSRF()
|
|
if csrfToken == "" {
|
|
t.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
|
|
session.GetAuthenticated(), req.URL.String())
|
|
|
|
cookie, err := req.Cookie("_oidc_raczylo_m")
|
|
if err != nil {
|
|
t.logger.Errorf("Main session cookie not found in request: %v", err)
|
|
} else {
|
|
t.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
|
}
|
|
|
|
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if state != csrfToken {
|
|
t.logger.Error("State parameter does not match CSRF token in session during callback")
|
|
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
code := req.URL.Query().Get("code")
|
|
if code == "" {
|
|
t.logger.Error("No code in callback")
|
|
t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
codeVerifier := session.GetCodeVerifier()
|
|
|
|
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err = t.verifyToken(tokenResponse.IDToken); err != nil {
|
|
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to extract claims during callback: %v", err)
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
nonceClaim, ok := claims["nonce"].(string)
|
|
if !ok || nonceClaim == "" {
|
|
t.logger.Error("Nonce claim missing in id_token during callback")
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
sessionNonce := session.GetNonce()
|
|
if sessionNonce == "" {
|
|
t.logger.Error("Nonce not found in session during callback")
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if nonceClaim != sessionNonce {
|
|
t.logger.Error("Nonce claim does not match session nonce during callback")
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
email, _ := claims["email"].(string)
|
|
if email == "" {
|
|
t.logger.Errorf("Email claim missing or empty in token during callback")
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if !t.isAllowedDomain(email) {
|
|
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if err := session.SetAuthenticated(true); err != nil {
|
|
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
|
|
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
session.SetEmail(email)
|
|
session.SetIDToken(tokenResponse.IDToken)
|
|
session.SetAccessToken(tokenResponse.AccessToken)
|
|
session.SetRefreshToken(tokenResponse.RefreshToken)
|
|
|
|
session.SetCSRF("")
|
|
session.SetNonce("")
|
|
session.SetCodeVerifier("")
|
|
|
|
session.ResetRedirectCount()
|
|
|
|
redirectPath := "/"
|
|
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
|
redirectPath = incomingPath
|
|
}
|
|
session.SetIncomingPath("")
|
|
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("Failed to save session after callback: %v", err)
|
|
t.sendErrorResponse(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
|
|
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
|
}
|
|
|
|
// determineExcludedURL checks if a URL path should bypass OIDC authentication.
|
|
// It compares the request path against configured excluded URL prefixes.
|
|
// Parameters:
|
|
// - currentRequest: The request path to check.
|
|
//
|
|
// Returns:
|
|
// - true if the URL should be excluded from authentication, false otherwise.
|
|
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
|
for excludedURL := range t.excludedURLs {
|
|
if strings.HasPrefix(currentRequest, excludedURL) {
|
|
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// determineScheme determines the URL scheme for building redirect URLs.
|
|
// It checks X-Forwarded-Proto header first, then TLS presence.
|
|
// Parameters:
|
|
// - req: The HTTP request to analyze.
|
|
//
|
|
// Returns:
|
|
// - The determined scheme: "https" or "http".
|
|
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
|
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
|
return scheme
|
|
}
|
|
if req.TLS != nil {
|
|
return "https"
|
|
}
|
|
return "http"
|
|
}
|
|
|
|
// determineHost determines the host for building redirect URLs.
|
|
// It checks X-Forwarded-Host header first, then falls back to req.Host.
|
|
// Parameters:
|
|
// - req: The HTTP request to analyze.
|
|
//
|
|
// Returns:
|
|
// - The determined host string (e.g., "example.com:8080").
|
|
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
|
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
|
return host
|
|
}
|
|
return req.Host
|
|
}
|
|
|
|
// isUserAuthenticated determines the authentication status and refresh requirements.
|
|
// It delegates to provider-specific validation methods that handle different token types
|
|
// and expiration behaviors.
|
|
// Parameters:
|
|
// - session: The session data containing authentication tokens.
|
|
//
|
|
// Returns:
|
|
// - authenticated (bool): True if the user has valid tokens.
|
|
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
|
|
// - expired (bool): True if the session is unauthenticated, the token is missing,
|
|
// or the token verification failed for reasons other than nearing/actual expiration.
|
|
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
|
if t.isAzureProvider() {
|
|
return t.validateAzureTokens(session)
|
|
} else if t.isGoogleProvider() {
|
|
return t.validateGoogleTokens(session)
|
|
}
|
|
return t.validateStandardTokens(session)
|
|
}
|
|
|
|
// defaultInitiateAuthentication initiates the OIDC authentication flow.
|
|
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
|
// stores authentication state, and redirects the user to the OIDC provider.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The HTTP request initiating authentication.
|
|
// - session: The session data to prepare for authentication.
|
|
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
|
|
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
|
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
|
|
|
const maxRedirects = 5
|
|
redirectCount := session.GetRedirectCount()
|
|
if redirectCount >= maxRedirects {
|
|
t.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
|
|
session.ResetRedirectCount()
|
|
t.sendErrorResponse(rw, req, "Authentication failed: Too many redirects", http.StatusLoopDetected)
|
|
return
|
|
}
|
|
|
|
session.IncrementRedirectCount()
|
|
|
|
csrfToken := uuid.NewString()
|
|
nonce, err := generateNonce()
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to generate nonce: %v", err)
|
|
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Generate PKCE code verifier and challenge if PKCE is enabled
|
|
var codeVerifier, codeChallenge string
|
|
if t.enablePKCE {
|
|
var err error
|
|
codeVerifier, err = generateCodeVerifier()
|
|
if err != nil {
|
|
t.logger.Errorf("Failed to generate code verifier: %v", err)
|
|
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
codeChallenge = deriveCodeChallenge(codeVerifier)
|
|
t.logger.Debugf("PKCE enabled, generated code challenge")
|
|
}
|
|
|
|
session.SetAuthenticated(false)
|
|
session.SetEmail("")
|
|
session.SetAccessToken("")
|
|
session.SetRefreshToken("")
|
|
session.SetIDToken("")
|
|
session.SetNonce("")
|
|
session.SetCodeVerifier("")
|
|
|
|
session.SetCSRF(csrfToken)
|
|
session.SetNonce(nonce)
|
|
if t.enablePKCE {
|
|
session.SetCodeVerifier(codeVerifier)
|
|
}
|
|
session.SetIncomingPath(req.URL.RequestURI())
|
|
t.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
|
|
|
|
session.MarkDirty()
|
|
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
|
|
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
t.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
|
|
csrfToken, nonce)
|
|
|
|
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
|
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
|
|
|
http.Redirect(rw, req, authURL, http.StatusFound)
|
|
}
|
|
|
|
// verifyToken is a convenience wrapper for token verification.
|
|
// It delegates to the configured token verifier interface.
|
|
// Parameters:
|
|
// - token: The token string to verify.
|
|
//
|
|
// Returns:
|
|
// - The result of calling t.tokenVerifier.VerifyToken(token).
|
|
func (t *TraefikOidc) verifyToken(token string) error {
|
|
return t.tokenVerifier.VerifyToken(token)
|
|
}
|
|
|
|
// safeLog provides nil-safe logging helpers
|
|
func (t *TraefikOidc) safeLogDebug(msg string) {
|
|
if t.logger != nil {
|
|
t.logger.Debug("%s", msg)
|
|
}
|
|
}
|
|
|
|
func (t *TraefikOidc) safeLogDebugf(format string, args ...interface{}) {
|
|
if t.logger != nil {
|
|
t.logger.Debugf(format, args...)
|
|
}
|
|
}
|
|
|
|
func (t *TraefikOidc) safeLogError(msg string) {
|
|
if t.logger != nil {
|
|
t.logger.Error("%s", msg)
|
|
}
|
|
}
|
|
|
|
func (t *TraefikOidc) safeLogErrorf(format string, args ...interface{}) {
|
|
if t.logger != nil {
|
|
t.logger.Errorf(format, args...)
|
|
}
|
|
}
|
|
|
|
func (t *TraefikOidc) safeLogInfo(msg string) {
|
|
if t.logger != nil {
|
|
t.logger.Info("%s", msg)
|
|
}
|
|
}
|
|
|
|
// buildAuthURL constructs the OIDC provider authorization URL.
|
|
// It builds the URL with all necessary parameters including client_id, scopes,
|
|
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
|
// Parameters:
|
|
// - redirectURL: The callback URL for after authentication.
|
|
// - state: The CSRF token for state validation.
|
|
// - nonce: The nonce for replay protection.
|
|
// - codeChallenge: The PKCE code challenge (if PKCE is enabled).
|
|
//
|
|
// Returns:
|
|
// - The fully constructed authorization URL string.
|
|
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
|
params := url.Values{}
|
|
params.Set("client_id", t.clientID)
|
|
params.Set("response_type", "code")
|
|
params.Set("redirect_uri", redirectURL)
|
|
params.Set("state", state)
|
|
params.Set("nonce", nonce)
|
|
|
|
if t.enablePKCE && codeChallenge != "" {
|
|
params.Set("code_challenge", codeChallenge)
|
|
params.Set("code_challenge_method", "S256")
|
|
}
|
|
|
|
scopes := make([]string, len(t.scopes))
|
|
copy(scopes, t.scopes)
|
|
|
|
if t.isGoogleProvider() {
|
|
params.Set("access_type", "offline")
|
|
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
|
|
|
params.Set("prompt", "consent")
|
|
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
|
} else if t.isAzureProvider() {
|
|
params.Set("response_mode", "query")
|
|
t.logger.Debug("Azure AD provider detected, added response_mode=query")
|
|
|
|
hasOfflineAccess := false
|
|
|
|
for _, scope := range scopes {
|
|
if scope == "offline_access" {
|
|
hasOfflineAccess = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !t.overrideScopes || (t.overrideScopes && len(t.scopes) == 0) {
|
|
if !hasOfflineAccess {
|
|
scopes = append(scopes, "offline_access")
|
|
t.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", t.overrideScopes, len(t.scopes))
|
|
}
|
|
} else {
|
|
t.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(t.scopes))
|
|
}
|
|
} else {
|
|
if !t.overrideScopes || (t.overrideScopes && len(t.scopes) == 0) {
|
|
hasOfflineAccess := false
|
|
for _, scope := range scopes {
|
|
if scope == "offline_access" {
|
|
hasOfflineAccess = true
|
|
break
|
|
}
|
|
}
|
|
if !hasOfflineAccess {
|
|
scopes = append(scopes, "offline_access")
|
|
t.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", t.overrideScopes, len(t.scopes))
|
|
}
|
|
} else {
|
|
t.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(t.scopes))
|
|
}
|
|
}
|
|
|
|
if len(scopes) > 0 {
|
|
finalScopeString := strings.Join(scopes, " ")
|
|
params.Set("scope", finalScopeString)
|
|
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
|
}
|
|
|
|
return t.buildURLWithParams(t.authURL, params)
|
|
}
|
|
|
|
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
|
// It handles both relative and absolute URLs, validates URL security,
|
|
// and properly encodes query parameters.
|
|
// Parameters:
|
|
// - baseURL: The base URL to append parameters to.
|
|
// - params: The query parameters to append.
|
|
//
|
|
// Returns:
|
|
// - The fully constructed URL string with appended query parameters.
|
|
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
|
|
if baseURL != "" {
|
|
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
|
|
if err := t.validateURL(baseURL); err != nil {
|
|
t.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
|
|
return ""
|
|
}
|
|
}
|
|
}
|
|
|
|
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
|
issuerURLParsed, err := url.Parse(t.issuerURL)
|
|
if err != nil {
|
|
t.logger.Errorf("Could not parse issuerURL: %s. Error: %v", t.issuerURL, err)
|
|
return ""
|
|
}
|
|
|
|
baseURLParsed, err := url.Parse(baseURL)
|
|
if err != nil {
|
|
t.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
|
|
return ""
|
|
}
|
|
|
|
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
|
|
|
|
if err := t.validateURL(resolvedURL.String()); err != nil {
|
|
t.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
|
|
return ""
|
|
}
|
|
|
|
resolvedURL.RawQuery = params.Encode()
|
|
return resolvedURL.String()
|
|
}
|
|
|
|
u, err := url.Parse(baseURL)
|
|
if err != nil {
|
|
t.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
|
|
return ""
|
|
}
|
|
|
|
if err := t.validateParsedURL(u); err != nil {
|
|
t.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
|
|
return ""
|
|
}
|
|
|
|
u.RawQuery = params.Encode()
|
|
return u.String()
|
|
}
|
|
|
|
// validateURL performs security validation on URLs to prevent SSRF attacks.
|
|
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
|
|
// Parameters:
|
|
// - urlStr: The URL string to validate.
|
|
//
|
|
// Returns:
|
|
// - An error if the URL is invalid or poses security risks, nil if valid.
|
|
func (t *TraefikOidc) validateURL(urlStr string) error {
|
|
if urlStr == "" {
|
|
return fmt.Errorf("empty URL")
|
|
}
|
|
|
|
u, err := url.Parse(urlStr)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL format: %w", err)
|
|
}
|
|
|
|
return t.validateParsedURL(u)
|
|
}
|
|
|
|
// validateParsedURL validates a parsed URL structure for security.
|
|
// It checks schemes, hosts, and paths to prevent malicious URLs.
|
|
// Parameters:
|
|
// - u: The parsed URL to validate.
|
|
//
|
|
// Returns:
|
|
// - An error if the URL is invalid or dangerous, nil if safe.
|
|
func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
|
|
allowedSchemes := map[string]bool{
|
|
"https": true,
|
|
"http": true,
|
|
}
|
|
|
|
if !allowedSchemes[u.Scheme] {
|
|
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
|
|
}
|
|
|
|
if u.Scheme == "http" {
|
|
t.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
|
|
}
|
|
|
|
if u.Host == "" {
|
|
return fmt.Errorf("missing host in URL")
|
|
}
|
|
|
|
if err := t.validateHost(u.Host); err != nil {
|
|
return fmt.Errorf("invalid host: %w", err)
|
|
}
|
|
|
|
if strings.Contains(u.Path, "..") {
|
|
return fmt.Errorf("path traversal detected in URL path")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateHost validates a hostname or IP address for security.
|
|
// It prevents access to localhost, private networks, and known metadata endpoints.
|
|
// Parameters:
|
|
// - host: The host string to validate (may include port).
|
|
//
|
|
// Returns:
|
|
// - An error if the host is dangerous or not allowed, nil if safe.
|
|
func (t *TraefikOidc) validateHost(host string) error {
|
|
hostname := host
|
|
if strings.Contains(host, ":") {
|
|
var err error
|
|
hostname, _, err = net.SplitHostPort(host)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid host format: %w", err)
|
|
}
|
|
}
|
|
|
|
ip := net.ParseIP(hostname)
|
|
if ip != nil {
|
|
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
|
return fmt.Errorf("access to private/internal IP addresses is not allowed: %s", ip.String())
|
|
}
|
|
|
|
if ip.IsUnspecified() || ip.IsMulticast() {
|
|
return fmt.Errorf("access to unspecified or multicast IP addresses is not allowed: %s", ip.String())
|
|
}
|
|
}
|
|
|
|
dangerousHosts := map[string]bool{
|
|
"localhost": true,
|
|
"127.0.0.1": true,
|
|
"::1": true,
|
|
"0.0.0.0": true,
|
|
"169.254.169.254": true,
|
|
"metadata.google.internal": true,
|
|
}
|
|
|
|
if dangerousHosts[strings.ToLower(hostname)] {
|
|
return fmt.Errorf("access to dangerous hostname is not allowed: %s", hostname)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// startTokenCleanup starts background cleanup goroutines for cache maintenance.
|
|
// It runs periodic cleanup of token cache, JWK cache, and session chunks.
|
|
// Includes panic recovery to ensure stability.
|
|
func (t *TraefikOidc) startTokenCleanup() {
|
|
if t == nil {
|
|
return
|
|
}
|
|
|
|
// Use singleton resource manager for token cleanup
|
|
rm := GetResourceManager()
|
|
taskName := "singleton-token-cleanup"
|
|
|
|
// Capture values for the cleanup function
|
|
tokenCache := t.tokenCache
|
|
jwkCache := t.jwkCache
|
|
sessionManager := t.sessionManager
|
|
logger := t.logger
|
|
|
|
cleanupInterval := 1 * time.Minute
|
|
if isTestMode() {
|
|
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
|
|
}
|
|
|
|
// Create cleanup function
|
|
cleanupFunc := func() {
|
|
if logger != nil && !isTestMode() {
|
|
logger.Debug("Starting token cleanup cycle")
|
|
}
|
|
if tokenCache != nil {
|
|
tokenCache.Cleanup()
|
|
}
|
|
if jwkCache != nil {
|
|
jwkCache.Cleanup()
|
|
}
|
|
if sessionManager != nil {
|
|
sessionManager.PeriodicChunkCleanup()
|
|
if logger != nil && !isTestMode() {
|
|
logger.Debug("Running session health monitoring")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Register as singleton task - will return existing if already registered
|
|
err := rm.RegisterBackgroundTask(taskName, cleanupInterval, cleanupFunc)
|
|
if err != nil {
|
|
logger.Errorf("Failed to register token cleanup task: %v", err)
|
|
return
|
|
}
|
|
|
|
// Start the task if not already running
|
|
if !rm.IsTaskRunning(taskName) {
|
|
rm.StartBackgroundTask(taskName)
|
|
logger.Debug("Started singleton token cleanup task")
|
|
} else {
|
|
logger.Debug("Token cleanup task already running, skipping duplicate")
|
|
}
|
|
}
|
|
|
|
// RevokeToken revokes a token locally by adding it to the blacklist cache.
|
|
// It removes the token from the verification cache and adds both the token
|
|
// and its JTI (if present) to the blacklist to prevent future use.
|
|
// Parameters:
|
|
// - token: The raw token string to revoke locally.
|
|
func (t *TraefikOidc) RevokeToken(token string) {
|
|
t.tokenCache.Delete(token)
|
|
|
|
if jwt, err := parseJWT(token); err == nil {
|
|
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
|
|
expiry := time.Now().Add(24 * time.Hour)
|
|
if t.tokenBlacklist != nil {
|
|
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
|
|
t.logger.Debugf("Locally revoked token JTI %s (added to blacklist)", jti)
|
|
}
|
|
}
|
|
}
|
|
|
|
expiry := time.Now().Add(24 * time.Hour)
|
|
if t.tokenBlacklist != nil {
|
|
t.tokenBlacklist.Set(token, true, time.Until(expiry))
|
|
t.logger.Debugf("Locally revoked token (added to blacklist)")
|
|
}
|
|
}
|
|
|
|
// RevokeTokenWithProvider revokes a token with the OIDC provider.
|
|
// It sends a revocation request to the provider's revocation endpoint
|
|
// with proper authentication and error recovery if available.
|
|
// Parameters:
|
|
// - token: The token to revoke.
|
|
// - tokenType: The type of token ("access_token" or "refresh_token").
|
|
//
|
|
// Returns:
|
|
// - An error if the request fails or the provider returns a non-OK status.
|
|
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
|
if t.revocationURL == "" {
|
|
return fmt.Errorf("token revocation endpoint is not configured or discovered")
|
|
}
|
|
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL)
|
|
|
|
data := url.Values{
|
|
"token": {token},
|
|
"token_type_hint": {tokenType},
|
|
"client_id": {t.clientID},
|
|
"client_secret": {t.clientSecret},
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(context.Background(), "POST", t.revocationURL, strings.NewReader(data.Encode()))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create token revocation request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
// Send the request with circuit breaker protection if available
|
|
var resp *http.Response
|
|
if t.errorRecoveryManager != nil {
|
|
serviceName := fmt.Sprintf("token-revocation-%s", t.issuerURL)
|
|
err = t.errorRecoveryManager.ExecuteWithRecovery(context.Background(), serviceName, func() error {
|
|
var reqErr error
|
|
resp, reqErr = t.httpClient.Do(req)
|
|
return reqErr
|
|
})
|
|
} else {
|
|
resp, err = t.httpClient.Do(req)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to send token revocation request: %w", err)
|
|
}
|
|
defer func() {
|
|
io.Copy(io.Discard, resp.Body)
|
|
resp.Body.Close()
|
|
}()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
limitReader := io.LimitReader(resp.Body, 1024*10)
|
|
body, _ := io.ReadAll(limitReader)
|
|
t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body))
|
|
return fmt.Errorf("token revocation failed with status %d", resp.StatusCode)
|
|
}
|
|
|
|
t.logger.Debugf("Token successfully revoked with provider")
|
|
return nil
|
|
}
|
|
|
|
// refreshToken attempts to refresh authentication tokens using the refresh token.
|
|
// It handles provider-specific refresh logic, validates new tokens, updates the session,
|
|
// and includes concurrency protection to prevent race conditions.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The HTTP request context.
|
|
// - session: The session data containing the refresh token.
|
|
//
|
|
// Returns:
|
|
// - true if refresh succeeded and session was updated, false if refresh failed,
|
|
// a concurrency conflict was detected, or saving the session failed.
|
|
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
|
session.refreshMutex.Lock()
|
|
defer session.refreshMutex.Unlock()
|
|
|
|
t.logger.Debug("Attempting to refresh token (mutex acquired)")
|
|
|
|
if !session.inUse {
|
|
t.logger.Debug("refreshToken aborted: Session no longer in use")
|
|
return false
|
|
}
|
|
|
|
initialRefreshToken := session.GetRefreshToken()
|
|
if initialRefreshToken == "" {
|
|
t.logger.Debug("No refresh token found in session")
|
|
return false
|
|
}
|
|
|
|
if t.isGoogleProvider() {
|
|
t.logger.Debug("Google OIDC provider detected for token refresh operation")
|
|
} else if t.isAzureProvider() {
|
|
t.logger.Debug("Azure AD provider detected for token refresh operation")
|
|
}
|
|
|
|
tokenPrefix := initialRefreshToken
|
|
if len(initialRefreshToken) > 10 {
|
|
tokenPrefix = initialRefreshToken[:10]
|
|
}
|
|
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
|
|
|
|
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
|
if err != nil {
|
|
errMsg := err.Error()
|
|
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
|
t.logger.Debug("Refresh token expired or revoked: %v", err)
|
|
// Clear all tokens and authentication state when refresh token is invalid
|
|
session.SetAuthenticated(false)
|
|
session.SetRefreshToken("")
|
|
session.SetAccessToken("")
|
|
session.SetIDToken("")
|
|
session.SetEmail("")
|
|
// Clear CSRF tokens as well to prevent any replay attacks
|
|
session.SetCSRF("")
|
|
session.SetNonce("")
|
|
session.SetCodeVerifier("")
|
|
if err = session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("Failed to clear session after invalid refresh token: %v", err)
|
|
}
|
|
} else if strings.Contains(errMsg, "invalid_client") {
|
|
t.logger.Errorf("Client credentials rejected: %v - check client_id and client_secret configuration", err)
|
|
} else if t.isGoogleProvider() && strings.Contains(errMsg, "invalid_request") {
|
|
t.logger.Errorf("Google OIDC provider error: %v - check scope configuration includes 'offline_access' and prompt=consent is used during authentication", err)
|
|
} else {
|
|
t.logger.Errorf("Token refresh failed: %v", err)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
if newToken.IDToken == "" {
|
|
t.logger.Info("Provider did not return a new ID token during refresh")
|
|
return false
|
|
}
|
|
|
|
if err = t.verifyToken(newToken.IDToken); err != nil {
|
|
t.logger.Debug("Failed to verify newly obtained ID token: %v", err)
|
|
return false
|
|
}
|
|
|
|
currentRefreshToken := session.GetRefreshToken()
|
|
if initialRefreshToken != currentRefreshToken {
|
|
t.logger.Infof("refreshToken aborted: Session refresh token changed concurrently during refresh attempt.")
|
|
return false
|
|
}
|
|
|
|
t.logger.Debugf("Concurrency check passed. Updating session with new tokens.")
|
|
|
|
claims, err := t.extractClaimsFunc(newToken.IDToken)
|
|
if err != nil {
|
|
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
|
return false
|
|
}
|
|
email, _ := claims["email"].(string)
|
|
if email == "" {
|
|
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
|
return false
|
|
}
|
|
session.SetEmail(email)
|
|
|
|
// Get token expiry information for logging
|
|
var expiryTime time.Time
|
|
if expClaim, ok := claims["exp"].(float64); ok {
|
|
expiryTime = time.Unix(int64(expClaim), 0)
|
|
t.logger.Debugf("New token expires at: %v (in %v)", expiryTime, time.Until(expiryTime))
|
|
}
|
|
|
|
session.SetIDToken(newToken.IDToken)
|
|
session.SetAccessToken(newToken.AccessToken)
|
|
|
|
if newToken.RefreshToken != "" {
|
|
t.logger.Debug("Received new refresh token from provider")
|
|
session.SetRefreshToken(newToken.RefreshToken)
|
|
} else {
|
|
t.logger.Debug("Provider did not return a new refresh token, keeping the existing one")
|
|
session.SetRefreshToken(initialRefreshToken)
|
|
}
|
|
|
|
if err := session.SetAuthenticated(true); err != nil {
|
|
t.logger.Errorf("refreshToken failed: Failed to set authenticated flag: %v", err)
|
|
// Clear tokens on failure to maintain consistent state
|
|
session.SetAccessToken("")
|
|
session.SetIDToken("")
|
|
session.SetRefreshToken("")
|
|
session.SetEmail("")
|
|
return false
|
|
}
|
|
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("refreshToken failed: Failed to save session after successful token refresh: %v", err)
|
|
// Reset authentication state since we couldn't persist it
|
|
session.SetAuthenticated(false)
|
|
return false
|
|
}
|
|
|
|
t.logger.Debugf("Token refresh successful and session saved")
|
|
return true
|
|
}
|
|
|
|
// isAllowedDomain checks if an email address is authorized based on domain or user whitelist.
|
|
// It validates against both allowed user domains and specific allowed users.
|
|
// Parameters:
|
|
// - email: The email address to validate.
|
|
//
|
|
// Returns:
|
|
// - true if the email is authorized (domain or user allowed), false if not authorized
|
|
// or if the email format is invalid.
|
|
func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
|
if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
|
|
return true
|
|
}
|
|
|
|
if len(t.allowedUsers) > 0 {
|
|
_, userAllowed := t.allowedUsers[strings.ToLower(email)]
|
|
if userAllowed {
|
|
t.logger.Debugf("Email %s is explicitly allowed in allowedUsers", email)
|
|
return true
|
|
}
|
|
}
|
|
|
|
if len(t.allowedUserDomains) > 0 {
|
|
parts := strings.Split(email, "@")
|
|
if len(parts) != 2 {
|
|
t.logger.Errorf("Invalid email format encountered: %s", email)
|
|
return false
|
|
}
|
|
|
|
domain := parts[1]
|
|
_, domainAllowed := t.allowedUserDomains[domain]
|
|
|
|
if domainAllowed {
|
|
t.logger.Debugf("Email domain %s is allowed", domain)
|
|
return true
|
|
} else {
|
|
t.logger.Debugf("Email domain %s is NOT allowed. Allowed domains: %v",
|
|
domain, keysFromMap(t.allowedUserDomains))
|
|
}
|
|
} else if len(t.allowedUsers) > 0 {
|
|
t.logger.Debugf("Email %s is not in the allowed users list: %v",
|
|
email, keysFromMap(t.allowedUsers))
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// keysFromMap extracts string keys from a map for logging purposes.
|
|
// Helper function to get keys from a map for logging.
|
|
// Parameters:
|
|
// - m: The map to extract keys from.
|
|
//
|
|
// Returns:
|
|
// - A slice of string keys.
|
|
func keysFromMap(m map[string]struct{}) []string {
|
|
keys := make([]string, 0, len(m))
|
|
for k := range m {
|
|
keys = append(keys, k)
|
|
}
|
|
return keys
|
|
}
|
|
|
|
// createCaseInsensitiveStringMap creates a map with lowercase keys for case-insensitive matching.
|
|
// This is used for case-insensitive matching of email addresses.
|
|
// Parameters:
|
|
// - items: The string items to convert to lowercase keys.
|
|
//
|
|
// Returns:
|
|
// - A map with lowercase string keys for case-insensitive lookups.
|
|
func createCaseInsensitiveStringMap(items []string) map[string]struct{} {
|
|
result := make(map[string]struct{})
|
|
for _, item := range items {
|
|
result[strings.ToLower(item)] = struct{}{}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// extractGroupsAndRoles extracts group and role information from token claims.
|
|
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
|
|
// Parameters:
|
|
// - idToken: The ID token containing claims to extract.
|
|
//
|
|
// Returns:
|
|
// - groups: Array of group names from the 'groups' claim.
|
|
// - roles: Array of role names from the 'roles' claim.
|
|
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
|
|
// but not arrays of strings.
|
|
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
|
claims, err := t.extractClaimsFunc(idToken)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
|
}
|
|
|
|
var groups []string
|
|
var roles []string
|
|
|
|
if groupsClaim, exists := claims["groups"]; exists {
|
|
groupsSlice, ok := groupsClaim.([]interface{})
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("groups claim is not an array")
|
|
} else {
|
|
for _, group := range groupsSlice {
|
|
if groupStr, ok := group.(string); ok {
|
|
t.logger.Debugf("Found group: %s", groupStr)
|
|
groups = append(groups, groupStr)
|
|
} else {
|
|
t.logger.Errorf("Non-string value found in groups claim array: %v", group)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if rolesClaim, exists := claims["roles"]; exists {
|
|
rolesSlice, ok := rolesClaim.([]interface{})
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("roles claim is not an array")
|
|
} else {
|
|
for _, role := range rolesSlice {
|
|
if roleStr, ok := role.(string); ok {
|
|
t.logger.Debugf("Found role: %s", roleStr)
|
|
roles = append(roles, roleStr)
|
|
} else {
|
|
t.logger.Errorf("Non-string value found in roles claim array: %v", role)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return groups, roles, nil
|
|
}
|
|
|
|
// buildFullURL constructs a complete URL from scheme, host, and path components.
|
|
// It handles absolute URLs in the path and ensures proper URL formatting.
|
|
// Parameters:
|
|
// - scheme: The URL scheme ("http" or "https").
|
|
// - host: The host name and optional port.
|
|
// - path: The path component (may be absolute URL itself).
|
|
//
|
|
// Returns:
|
|
// - The combined absolute URL string (e.g., "https://example.com:8080/resource").
|
|
func buildFullURL(scheme, host, path string) string {
|
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
|
return path
|
|
}
|
|
|
|
if !strings.HasPrefix(path, "/") {
|
|
path = "/" + path
|
|
}
|
|
|
|
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
|
}
|
|
|
|
// ExchangeCodeForToken exchanges an authorization code for tokens.
|
|
// This is a wrapper method that delegates to the internal token exchange logic
|
|
// while still allowing mocking for tests.
|
|
// Parameters:
|
|
// - ctx: The request context.
|
|
// - grantType: The OAuth 2.0 grant type ("authorization_code").
|
|
// - codeOrToken: The authorization code received from the provider.
|
|
// - redirectURL: The redirect URI used in the authorization request.
|
|
// - codeVerifier: The PKCE code verifier (if PKCE is enabled).
|
|
//
|
|
// Returns:
|
|
// - The token response containing access token, ID token, and refresh token.
|
|
// - An error if the token exchange fails.
|
|
func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
|
}
|
|
|
|
// GetNewTokenWithRefreshToken refreshes tokens using a refresh token.
|
|
// This is a wrapper method that delegates to the internal refresh token logic
|
|
// while still allowing mocking for tests.
|
|
// Parameters:
|
|
// - refreshToken: The refresh token to use for obtaining new tokens.
|
|
//
|
|
// Returns:
|
|
// - The token response containing new access token, ID token, and potentially new refresh token.
|
|
// - An error if the refresh fails.
|
|
func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
|
return t.getNewTokenWithRefreshToken(refreshToken)
|
|
}
|
|
|
|
// sendErrorResponse sends an appropriate error response based on the request's Accept header.
|
|
// It sends JSON responses for clients that accept JSON, otherwise sends HTML error pages.
|
|
// Parameters:
|
|
// - rw: The HTTP response writer.
|
|
// - req: The HTTP request (used to check Accept header).
|
|
// - message: The error message to display.
|
|
// - code: The HTTP status code to set for the response.
|
|
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
|
acceptHeader := req.Header.Get("Accept")
|
|
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
t.logger.Debugf("Sending JSON error response (code %d): %s", code, message)
|
|
rw.Header().Set("Content-Type", "application/json")
|
|
rw.WriteHeader(code)
|
|
json.NewEncoder(rw).Encode(map[string]interface{}{
|
|
"error": http.StatusText(code),
|
|
"error_description": message,
|
|
"status_code": code,
|
|
})
|
|
return
|
|
}
|
|
|
|
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
|
|
|
|
returnURL := "/"
|
|
|
|
htmlBody := fmt.Sprintf(`
|
|
<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<title>Authentication Error</title>
|
|
<style>
|
|
body { font-family: sans-serif; padding: 20px; background-color: #f8f9fa; color: #343a40; }
|
|
h1 { color: #dc3545; }
|
|
a { color: #007bff; text-decoration: none; }
|
|
a:hover { text-decoration: underline; }
|
|
.container { max-width: 600px; margin: auto; background: #fff; padding: 20px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
|
|
</style>
|
|
</head>
|
|
<body>
|
|
<div class="container">
|
|
<h1>Authentication Error</h1>
|
|
<p>%s</p>
|
|
<p><a href="%s">Return to application</a></p>
|
|
</div>
|
|
</body>
|
|
</html>`, message, returnURL)
|
|
|
|
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
rw.WriteHeader(code)
|
|
_, _ = rw.Write([]byte(htmlBody))
|
|
}
|
|
|
|
// isGoogleProvider detects if the configured OIDC provider is Google.
|
|
// It checks the issuer URL for Google-specific domains.
|
|
// Returns:
|
|
// - true if the provider is Google, false otherwise.
|
|
func (t *TraefikOidc) isGoogleProvider() bool {
|
|
return strings.Contains(t.issuerURL, "google") || strings.Contains(t.issuerURL, "accounts.google.com")
|
|
}
|
|
|
|
// isAzureProvider detects if the configured OIDC provider is Azure AD.
|
|
// It checks the issuer URL for Microsoft Azure AD domains.
|
|
// Returns:
|
|
// - true if the provider is Azure AD, false otherwise.
|
|
func (t *TraefikOidc) isAzureProvider() bool {
|
|
return strings.Contains(t.issuerURL, "login.microsoftonline.com") ||
|
|
strings.Contains(t.issuerURL, "sts.windows.net") ||
|
|
strings.Contains(t.issuerURL, "login.windows.net")
|
|
}
|
|
|
|
// validateAzureTokens validates tokens with Azure AD-specific logic.
|
|
// Azure tokens may be opaque access tokens that cannot be verified as JWTs,
|
|
// so this method handles both JWT and opaque token scenarios.
|
|
// Parameters:
|
|
// - session: The session data containing tokens to validate.
|
|
//
|
|
// Returns:
|
|
// - authenticated: Whether the user has valid authentication.
|
|
// - needsRefresh: Whether tokens need to be refreshed.
|
|
// - expired: Whether tokens have expired and cannot be refreshed.
|
|
func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, bool) {
|
|
if !session.GetAuthenticated() {
|
|
t.logger.Debug("Azure user is not authenticated according to session flag")
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debug("Azure session not authenticated, but refresh token exists. Signaling need for refresh.")
|
|
return false, true, false
|
|
}
|
|
return false, true, false
|
|
}
|
|
|
|
accessToken := session.GetAccessToken()
|
|
idToken := session.GetIDToken()
|
|
|
|
if accessToken != "" {
|
|
if strings.Count(accessToken, ".") == 2 {
|
|
if err := t.verifyToken(accessToken); err != nil {
|
|
if idToken != "" {
|
|
if err := t.verifyToken(idToken); err != nil {
|
|
t.logger.Debugf("Azure: Both access and ID token validation failed: %v", err)
|
|
if session.GetRefreshToken() != "" {
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
return t.validateTokenExpiry(session, idToken)
|
|
}
|
|
if session.GetRefreshToken() != "" {
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
return t.validateTokenExpiry(session, accessToken)
|
|
} else {
|
|
t.logger.Debug("Azure access token appears opaque, treating as valid")
|
|
if idToken != "" {
|
|
return t.validateTokenExpiry(session, idToken)
|
|
}
|
|
return true, false, false
|
|
}
|
|
}
|
|
|
|
if idToken != "" {
|
|
if err := t.verifyToken(idToken); err != nil {
|
|
if strings.Contains(err.Error(), "token has expired") {
|
|
if session.GetRefreshToken() != "" {
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
if session.GetRefreshToken() != "" {
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
return t.validateTokenExpiry(session, idToken)
|
|
}
|
|
|
|
if session.GetRefreshToken() != "" {
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
|
|
// validateGoogleTokens handles Google-specific token validation logic.
|
|
// Currently delegates to standard token validation but provides a hook
|
|
// for Google-specific validation requirements in the future.
|
|
// Parameters:
|
|
// - session: The session data containing tokens to validate.
|
|
//
|
|
// Returns:
|
|
// - authenticated: Whether the user has valid authentication.
|
|
// - needsRefresh: Whether tokens need to be refreshed.
|
|
// - expired: Whether tokens have expired and cannot be refreshed.
|
|
func (t *TraefikOidc) validateGoogleTokens(session *SessionData) (bool, bool, bool) {
|
|
return t.validateStandardTokens(session)
|
|
}
|
|
|
|
// validateStandardTokens handles standard OIDC token validation logic.
|
|
// This is the default validation method for generic OIDC providers.
|
|
// It verifies ID tokens and handles access tokens appropriately.
|
|
// Parameters:
|
|
// - session: The session data containing tokens to validate.
|
|
//
|
|
// Returns:
|
|
// - authenticated: Whether the user has valid authentication.
|
|
// - needsRefresh: Whether tokens need to be refreshed.
|
|
// - expired: Whether tokens have expired and cannot be refreshed.
|
|
func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool, bool) {
|
|
authenticated := session.GetAuthenticated()
|
|
// Removed debug output
|
|
if !authenticated {
|
|
t.logger.Debug("User is not authenticated according to session flag")
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.")
|
|
return false, true, false
|
|
}
|
|
return false, false, false
|
|
}
|
|
|
|
accessToken := session.GetAccessToken()
|
|
// Removed debug output
|
|
if accessToken == "" {
|
|
t.logger.Debug("Authenticated flag set, but no access token found in session")
|
|
if session.GetRefreshToken() != "" {
|
|
// Check if we have an ID token to determine if we're beyond grace period
|
|
// When access token is missing, check ID token expiry to determine if refresh is viable
|
|
idToken := session.GetIDToken()
|
|
t.logger.Debugf("Checking ID token for grace period: ID token present: %v", idToken != "")
|
|
if idToken != "" {
|
|
// Try to parse the ID token to check its expiry
|
|
parts := strings.Split(idToken, ".")
|
|
if len(parts) == 3 {
|
|
// Decode the claims part
|
|
claimsData, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err == nil {
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(claimsData, &claims); err == nil {
|
|
if expClaim, ok := claims["exp"].(float64); ok {
|
|
expTime := time.Unix(int64(expClaim), 0)
|
|
if time.Now().After(expTime) {
|
|
expiredDuration := time.Since(expTime)
|
|
if expiredDuration > t.refreshGracePeriod {
|
|
t.logger.Debugf("ID token expired beyond grace period (%v > %v), must re-authenticate",
|
|
expiredDuration, t.refreshGracePeriod)
|
|
return false, false, true // expired, cannot refresh
|
|
}
|
|
t.logger.Debugf("ID token expired %v ago, within grace period %v, allowing refresh",
|
|
expiredDuration, t.refreshGracePeriod)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
t.logger.Debug("Access token missing, but refresh token exists. Signaling need for refresh.")
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
|
|
idToken := session.GetIDToken()
|
|
if idToken == "" {
|
|
t.logger.Debug("Authenticated flag set with access token, but no ID token found in session (possibly opaque token)")
|
|
session.SetAuthenticated(true)
|
|
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debug("ID token missing but refresh token exists. Signaling conditional refresh to obtain ID token.")
|
|
return true, true, false
|
|
}
|
|
return true, false, false
|
|
}
|
|
|
|
if err := t.verifyToken(idToken); err != nil {
|
|
if strings.Contains(err.Error(), "token has expired") {
|
|
t.logger.Debugf("ID token signature/claims valid but token expired, needs refresh")
|
|
if session.GetRefreshToken() != "" {
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
|
|
t.logger.Errorf("ID token verification failed (non-expiration): %v", err)
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debug("ID token verification failed, but refresh token exists. Signaling need for refresh.")
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
|
|
return t.validateTokenExpiry(session, idToken)
|
|
}
|
|
|
|
// validateTokenExpiry checks if a token is nearing expiration and needs refresh.
|
|
// It uses the configured grace period to determine when proactive refresh should occur.
|
|
// Parameters:
|
|
// - session: The session data for refresh token availability.
|
|
// - token: The token to check expiry for.
|
|
//
|
|
// Returns:
|
|
// - authenticated: Whether the token is currently valid.
|
|
// - needsRefresh: Whether the token is nearing expiration and should be refreshed.
|
|
// - expired: Whether the token is invalid or verification failed.
|
|
func (t *TraefikOidc) validateTokenExpiry(session *SessionData, token string) (bool, bool, bool) {
|
|
cachedClaims, found := t.tokenCache.Get(token)
|
|
if !found {
|
|
t.logger.Debug("Claims not found in cache after successful token verification")
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debug("Claims missing post-verification, attempting refresh to recover.")
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
|
|
expClaim, ok := cachedClaims["exp"].(float64)
|
|
if !ok {
|
|
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debug("Token missing 'exp' claim, but refresh token exists. Signaling need for refresh.")
|
|
return false, true, false
|
|
}
|
|
return false, false, true
|
|
}
|
|
|
|
expTime := int64(expClaim)
|
|
expTimeObj := time.Unix(expTime, 0)
|
|
nowObj := time.Now()
|
|
|
|
// Check if token has already expired
|
|
if expTimeObj.Before(nowObj) {
|
|
// Token has expired
|
|
expiredDuration := nowObj.Sub(expTimeObj)
|
|
|
|
t.logger.Debugf("Token expired %v ago, grace period is %v",
|
|
expiredDuration, t.refreshGracePeriod)
|
|
|
|
// If we have a refresh token, always attempt to use it regardless of grace period
|
|
// The refresh token has its own expiry and the provider will reject it if invalid
|
|
if session.GetRefreshToken() != "" {
|
|
t.logger.Debugf("Token expired, attempting refresh with available refresh token")
|
|
return false, true, false // needs refresh
|
|
}
|
|
|
|
// No refresh token available - must re-authenticate
|
|
t.logger.Debugf("Token expired and no refresh token available, must re-authenticate")
|
|
return false, false, true // expired, cannot refresh
|
|
}
|
|
|
|
// Token not yet expired - check if nearing expiration
|
|
refreshThreshold := nowObj.Add(t.refreshGracePeriod)
|
|
|
|
t.logger.Debugf("Token expires at %v, now is %v, refresh threshold is %v",
|
|
expTimeObj.Format(time.RFC3339),
|
|
nowObj.Format(time.RFC3339),
|
|
refreshThreshold.Format(time.RFC3339))
|
|
|
|
if expTimeObj.Before(refreshThreshold) {
|
|
remainingSeconds := int64(time.Until(expTimeObj).Seconds())
|
|
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh",
|
|
remainingSeconds, t.refreshGracePeriod)
|
|
|
|
if session.GetRefreshToken() != "" {
|
|
return true, true, false
|
|
}
|
|
|
|
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
|
|
return true, false, false
|
|
}
|
|
|
|
t.logger.Debugf("Token is valid and not nearing expiration (expires in %d seconds, outside %s grace period)",
|
|
int64(time.Until(expTimeObj).Seconds()), t.refreshGracePeriod)
|
|
|
|
return true, false, false
|
|
}
|
|
|
|
// Close gracefully shuts down the TraefikOidc middleware instance.
|
|
// It cancels contexts, stops background goroutines, closes HTTP connections,
|
|
// cleans up caches, and releases all resources. Safe to call multiple times.
|
|
// Returns:
|
|
// - An error if shutdown times out or resource cleanup fails.
|
|
func (t *TraefikOidc) Close() error {
|
|
var closeErr error
|
|
t.shutdownOnce.Do(func() {
|
|
t.safeLogDebug("Closing TraefikOidc plugin instance")
|
|
|
|
// Get resource manager for cleanup
|
|
rm := GetResourceManager()
|
|
|
|
// Stop singleton tasks related to this instance
|
|
rm.StopBackgroundTask("singleton-token-cleanup")
|
|
rm.StopBackgroundTask("singleton-metadata-refresh")
|
|
|
|
// Remove reference for this instance
|
|
rm.RemoveReference(t.name)
|
|
|
|
if t.cancelFunc != nil {
|
|
t.cancelFunc()
|
|
t.safeLogDebug("Context cancellation signaled to all goroutines")
|
|
}
|
|
|
|
// Clean up legacy stop channels if they exist
|
|
if t.tokenCleanupStopChan != nil {
|
|
close(t.tokenCleanupStopChan)
|
|
t.safeLogDebug("tokenCleanupStopChan closed")
|
|
}
|
|
if t.metadataRefreshStopChan != nil {
|
|
close(t.metadataRefreshStopChan)
|
|
t.safeLogDebug("metadataRefreshStopChan closed")
|
|
}
|
|
|
|
if t.goroutineWG != nil {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
t.goroutineWG.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
t.safeLogDebug("All background goroutines stopped gracefully")
|
|
case <-time.After(10 * time.Second):
|
|
t.safeLogError("Timeout waiting for background goroutines to stop")
|
|
}
|
|
} else {
|
|
t.safeLogDebug("No goroutineWG to wait for (likely in test)")
|
|
}
|
|
|
|
if t.httpClient != nil {
|
|
if transport, ok := t.httpClient.Transport.(*http.Transport); ok {
|
|
transport.CloseIdleConnections()
|
|
t.safeLogDebug("HTTP client idle connections closed")
|
|
}
|
|
}
|
|
|
|
if t.tokenHTTPClient != nil {
|
|
if transport, ok := t.tokenHTTPClient.Transport.(*http.Transport); ok {
|
|
transport.CloseIdleConnections()
|
|
t.safeLogDebug("Token HTTP client idle connections closed")
|
|
}
|
|
if t.tokenHTTPClient.Transport != t.httpClient.Transport {
|
|
if transport, ok := t.tokenHTTPClient.Transport.(*http.Transport); ok {
|
|
transport.CloseIdleConnections()
|
|
t.safeLogDebug("Token HTTP client transport closed (separate from main)")
|
|
}
|
|
}
|
|
}
|
|
|
|
if t.tokenBlacklist != nil {
|
|
t.tokenBlacklist.Close()
|
|
t.safeLogDebug("tokenBlacklist closed")
|
|
}
|
|
if t.metadataCache != nil {
|
|
t.metadataCache.Close()
|
|
t.safeLogDebug("metadataCache closed")
|
|
}
|
|
if t.tokenCache != nil {
|
|
t.tokenCache.Close()
|
|
t.safeLogDebug("tokenCache closed")
|
|
}
|
|
|
|
if t.jwkCache != nil {
|
|
t.jwkCache.Close()
|
|
t.safeLogDebug("t.jwkCache.Close() called as per original instruction.")
|
|
}
|
|
|
|
// Shutdown session manager and its background cleanup routines
|
|
if t.sessionManager != nil {
|
|
if err := t.sessionManager.Shutdown(); err != nil {
|
|
t.safeLogErrorf("Error shutting down session manager: %v", err)
|
|
} else {
|
|
t.safeLogDebug("sessionManager shutdown completed")
|
|
}
|
|
}
|
|
|
|
// Clean up error recovery manager
|
|
if t.errorRecoveryManager != nil && t.errorRecoveryManager.gracefulDegradation != nil {
|
|
t.errorRecoveryManager.gracefulDegradation.Close()
|
|
t.safeLogDebug("Error recovery manager graceful degradation closed")
|
|
}
|
|
|
|
// Stop all global background tasks
|
|
taskRegistry := GetGlobalTaskRegistry()
|
|
taskRegistry.StopAllTasks()
|
|
t.safeLogDebug("All global background tasks stopped")
|
|
|
|
CleanupGlobalMemoryPools()
|
|
t.safeLogDebug("Global memory pools cleaned up")
|
|
|
|
// Force garbage collection to help with memory cleanup after shutdown
|
|
runtime.GC()
|
|
t.safeLogDebug("Forced garbage collection after shutdown")
|
|
|
|
t.safeLogInfo("TraefikOidc plugin instance closed successfully.")
|
|
})
|
|
return closeErr
|
|
}
|
|
|
|
// isAjaxRequest determines if the request is an AJAX/fetch request that should
|
|
// receive JSON responses instead of HTML redirects.
|
|
// Returns true if the request contains AJAX indicators.
|
|
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
|
|
// Check for XMLHttpRequest header (set by jQuery and many AJAX libraries)
|
|
if req.Header.Get("X-Requested-With") == "XMLHttpRequest" {
|
|
return true
|
|
}
|
|
|
|
// Check if client prefers JSON response
|
|
acceptHeader := req.Header.Get("Accept")
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
return true
|
|
}
|
|
|
|
// Check for fetch API requests (often contain these headers)
|
|
if req.Header.Get("Sec-Fetch-Mode") == "cors" {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isRefreshTokenExpired checks if the refresh token is likely expired based on
|
|
// when it was last obtained. Refresh tokens typically expire after 6+ hours.
|
|
// Returns true if the refresh token is likely expired and refresh should be skipped.
|
|
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
|
refreshTokenIssuedAt := session.GetRefreshTokenIssuedAt()
|
|
if refreshTokenIssuedAt.IsZero() {
|
|
// If we don't have issue time, assume it might be old but try refresh anyway
|
|
return false
|
|
}
|
|
|
|
// Consider refresh token expired if it's older than 6 hours
|
|
// This is a conservative estimate as most providers use 6-24 hour expiry
|
|
refreshTokenMaxAge := 6 * time.Hour
|
|
return time.Since(refreshTokenIssuedAt) > refreshTokenMaxAge
|
|
}
|