Update documentation to the higher standards.

This commit is contained in:
2025-04-05 11:31:45 +01:00
parent 46c2f98a15
commit 1910cd6000
10 changed files with 1266 additions and 256 deletions
+10 -2
View File
@@ -2,8 +2,16 @@ package traefikoidc
import "time"
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
// It stops when a value is received on the stop channel.
// autoCleanupRoutine periodically calls the provided cleanup function.
// It starts a ticker with the given interval and executes the cleanup function
// on each tick. The routine stops gracefully when a signal is received on the
// stop channel. This is typically used for background cleanup tasks like
// expiring cache entries.
//
// Parameters:
// - interval: The time duration between cleanup calls.
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
// - cleanup: The function to call periodically for cleanup tasks.
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
+34 -12
View File
@@ -46,7 +46,9 @@ type Cache struct {
// DefaultMaxSize is the default maximum number of items in the cache.
const DefaultMaxSize = 500
// NewCache creates a new empty cache instance that is ready for use.
// NewCache creates a new empty cache instance with default settings.
// It initializes the internal maps and list, sets the default maximum size,
// and starts the automatic cleanup goroutine.
func NewCache() *Cache {
c := &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
@@ -60,8 +62,12 @@ func NewCache() *Cache {
return c
}
// Set adds or updates an item in the cache with the specified expiration duration.
// It moves the item to the most recently used position.
// Set adds or updates an item in the cache with the specified key, value, and expiration duration.
// If the key already exists, its value and expiration time are updated, and it's moved
// to the most recently used position in the LRU list.
// If the key does not exist and the cache is full, the least recently used item is evicted
// before adding the new item.
// The expiration duration is relative to the time Set is called.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -95,8 +101,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.elems[key] = elem
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Moving the accessed item to the most recently used position.
// Get retrieves an item from the cache by its key.
// If the item exists and has not expired, its value and true are returned.
// Accessing an item moves it to the most recently used position in the LRU list.
// If the item does not exist or has expired, nil and false are returned, and the
// expired item is removed from the cache.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -120,7 +129,9 @@ func (c *Cache) Get(key string) (interface{}, bool) {
return item.Value, true
}
// Delete removes an item from the cache.
// Delete removes an item from the cache by its key.
// If the key exists, the corresponding item is removed from the cache storage
// and the LRU list.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -128,8 +139,10 @@ func (c *Cache) Delete(key string) {
c.removeItem(key)
}
// Cleanup removes all expired items from the cache. This should be called periodically
// to prevent memory bloat from expired entries.
// Cleanup iterates through the cache and removes all items that have expired.
// An item is considered expired if the current time is after its ExpiresAt timestamp.
// This method is called automatically by the auto-cleanup goroutine, but can also
// be called manually.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -143,7 +156,11 @@ func (c *Cache) Cleanup() {
}
}
// evictOldest removes the least recently used item from the cache.
// evictOldest removes the least recently used (oldest) item from the cache.
// It first attempts to find and remove an expired item from the front of the LRU list.
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
// This method is called internally by Set when the cache reaches its maximum size.
// Note: This function assumes the write lock is already held.
func (c *Cache) evictOldest() {
now := time.Now()
elem := c.order.Front()
@@ -167,7 +184,9 @@ func (c *Cache) evictOldest() {
}
}
// removeItem removes an item from both the cache and the LRU tracking structures.
// removeItem removes an item specified by the key from the cache's internal storage (items map)
// and its corresponding entry from the LRU list (order list and elems map).
// Note: This function assumes the write lock is already held.
func (c *Cache) removeItem(key string) {
delete(c.items, key)
if elem, ok := c.elems[key]; ok {
@@ -176,12 +195,15 @@ func (c *Cache) removeItem(key string) {
}
}
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method
// at the interval specified by c.autoCleanupInterval.
// It uses the autoCleanupRoutine helper function.
func (c *Cache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
}
// Close terminates the auto cleanup goroutine.
// Close stops the automatic cleanup goroutine associated with this cache instance.
// It should be called when the cache is no longer needed to prevent resource leaks.
func (c *Cache) Close() {
close(c.stopCleanup)
}
+129 -38
View File
@@ -15,10 +15,14 @@ import (
"time"
)
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
// the authentication request.
// generateNonce creates a cryptographically secure random string suitable for use as an OIDC nonce.
// The nonce is used during the authentication flow to mitigate replay attacks by associating
// the ID token with the specific authentication request.
// It generates 32 random bytes and encodes them using base64 URL encoding.
//
// Returns:
// - A base64 URL encoded random string (nonce).
// - An error if the random byte generation fails.
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -28,9 +32,13 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// generateCodeVerifier creates a cryptographically secure random string
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
// characters long, per the PKCE spec (RFC 7636).
// generateCodeVerifier creates a cryptographically secure random string suitable for use as a PKCE code verifier.
// According to RFC 7636, the verifier should be a high-entropy string between 43 and 128 characters long.
// This function generates 32 random bytes, resulting in a 43-character base64 URL encoded string.
//
// Returns:
// - A base64 URL encoded random string (code verifier).
// - An error if the random byte generation fails.
func generateCodeVerifier() (string, error) {
// Using 32 bytes (256 bits) will produce a 43 character base64url string
verifierBytes := make([]byte, 32)
@@ -41,8 +49,15 @@ func generateCodeVerifier() (string, error) {
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
}
// deriveCodeChallenge creates a code challenge from a code verifier
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
// deriveCodeChallenge computes the PKCE code challenge from a given code verifier.
// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding)
// as defined in RFC 7636.
//
// Parameters:
// - codeVerifier: The high-entropy string generated by generateCodeVerifier.
//
// Returns:
// - The base64 URL encoded SHA-256 hash of the code verifier (code challenge).
func deriveCodeChallenge(codeVerifier string) string {
// Calculate SHA-256 hash of the code verifier
hasher := sha256.New()
@@ -73,14 +88,22 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
// It supports both authorization code and refresh token grant types.
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens)
// and the "refresh_token" grant type (using a refresh token to obtain new tokens).
// It includes necessary parameters like client credentials and handles PKCE verification if applicable.
// The function follows redirects and handles potential errors during the exchange.
//
// Parameters:
// - ctx: Context for the HTTP request
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
// - codeVerifier: Optional PKCE code verifier for authorization code grant
// - ctx: The context for the outgoing HTTP request.
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token").
// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant).
// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant).
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used).
//
// Returns:
// - A TokenResponse containing the obtained tokens (ID, access, refresh).
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant).
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
@@ -140,8 +163,16 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
return &tokenResponse, nil
}
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
// This is used to refresh access tokens before they expire.
// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh)
// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the
// "refresh_token" grant type.
//
// Parameters:
// - refreshToken: The refresh token previously obtained during authentication or a prior refresh.
//
// Returns:
// - A TokenResponse containing the newly obtained tokens.
// - An error if the refresh operation fails.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
@@ -153,8 +184,17 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
return tokenResponse, nil
}
// extractClaims parses a JWT token and extracts its claims.
// It handles base64url decoding and JSON parsing of the token payload.
// extractClaims decodes the payload (claims set) part of a JWT string.
// It splits the JWT into its three parts, base64 URL decodes the second part (payload),
// and unmarshals the resulting JSON into a map.
// Note: This function does *not* validate the token's signature or claims.
//
// Parameters:
// - tokenString: The raw JWT string.
//
// Returns:
// - A map representing the JSON claims extracted from the token payload.
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -182,21 +222,36 @@ type TokenCache struct {
cache *Cache
}
// NewTokenCache creates a new TokenCache instance.
// NewTokenCache creates and initializes a new TokenCache.
// It internally creates a new generic Cache instance for storage.
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set stores a token's claims in the cache with an expiration time.
// Set stores the claims associated with a specific token string in the cache.
// It prefixes the token string to avoid potential collisions with other cache types
// and sets the provided expiration duration.
//
// Parameters:
// - token: The raw token string (used as the key).
// - claims: The map of claims associated with the token.
// - expiration: The duration for which the cache entry should be valid.
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
}
// Get retrieves a token's claims from the cache.
// Returns the claims and a boolean indicating if the token was found.
// Get retrieves the cached claims for a given token string.
// It prefixes the token string before querying the underlying cache.
//
// Parameters:
// - token: The raw token string to look up.
//
// Returns:
// - The cached claims map if found and valid.
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
@@ -207,20 +262,34 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache.
// Delete removes the cached entry for a specific token string.
// It prefixes the token string before calling the underlying cache's Delete method.
//
// Parameters:
// - token: The raw token string to remove from the cache.
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
}
// Cleanup removes expired tokens from the cache.
// Cleanup triggers the cleanup process for the underlying generic cache,
// removing expired token entries.
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
// exchangeCodeForToken exchanges an authorization code for tokens.
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
// The code verifier is only included in the token request if PKCE is enabled.
// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically
// for the "authorization_code" grant type. It handles the conditional inclusion of the
// PKCE code verifier based on the middleware's configuration (t.enablePKCE).
//
// Parameters:
// - code: The authorization code received from the OIDC provider.
// - redirectURL: The redirect URI used in the initial authorization request.
// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled).
//
// Returns:
// - A TokenResponse containing the obtained tokens.
// - An error if the code exchange fails.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
ctx := context.Background()
@@ -237,8 +306,15 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, code
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings.
// Used for efficient lookups in allowed domains and roles.
// createStringMap converts a slice of strings into a map[string]struct{} (a set).
// This is useful for creating efficient lookups (O(1) average time complexity)
// for checking the presence of items like allowed domains, roles, or groups.
//
// Parameters:
// - keys: A slice of strings to be added to the set.
//
// Returns:
// - A map where the keys are the strings from the input slice and the values are empty structs.
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
@@ -247,9 +323,17 @@ func createStringMap(keys []string) map[string]struct{} {
return result
}
// handleLogout manages the OIDC logout process.
// It clears the session and redirects either to the OIDC provider's
// end session endpoint (if available) or to the configured post-logout URL.
// handleLogout processes requests to the configured logout path.
// It performs the following steps:
// 1. Retrieves the current user session.
// 2. Gets the access token (ID token hint) from the session.
// 3. Clears all authentication-related data from the session cookies.
// 4. Determines the final post-logout redirect URI.
// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available,
// it builds the OIDC logout URL and redirects the user agent to the provider for logout.
// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI.
//
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
@@ -291,11 +375,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
// BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's
// end_session_endpoint, including the required id_token_hint and optional
// post_logout_redirect_uri parameters as query arguments.
//
// Parameters:
// - endSessionURL: The OIDC provider's end session endpoint
// - idToken: The ID token to be invalidated
// - postLogoutRedirectURI: Where to redirect after logout completes
// - endSessionURL: The URL of the OIDC provider's end session endpoint.
// - idToken: The ID token previously issued to the user (used as id_token_hint).
// - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout.
//
// Returns:
// - The fully constructed logout URL string.
// - An error if the provided endSessionURL is invalid.
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
+61
View File
@@ -45,6 +45,21 @@ type JWKCacheInterface interface {
Cleanup()
}
// GetJWKS retrieves the JSON Web Key Set (JWKS) from the cache or fetches it from the provider.
// It first checks if a valid, non-expired JWKS is present in the cache. If so, it returns the cached version.
// Otherwise, it attempts to fetch the JWKS from the specified jwksURL using the provided httpClient.
// If the fetch is successful, the JWKS is stored in the cache with an expiration time based on CacheLifetime
// (defaulting to 1 hour if not set) and returned.
// This method uses double-checked locking to minimize contention when the cache needs refreshing.
//
// Parameters:
// - ctx: Context for the HTTP request if fetching is required.
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
// - httpClient: The HTTP client to use for fetching the JWKS.
//
// Returns:
// - A pointer to the JWKSet containing the keys.
// - An error if fetching fails or the response cannot be decoded.
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -74,6 +89,8 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
return jwks, nil
}
// Cleanup removes the cached JWKS if it has expired.
// This is intended to be called periodically to ensure stale JWKS data is cleared.
func (c *JWKCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -84,6 +101,17 @@ func (c *JWKCache) Cleanup() {
}
}
// fetchJWKS retrieves the JSON Web Key Set (JWKS) from the specified URL.
// It uses the provided context and HTTP client to make the request.
//
// Parameters:
// - ctx: Context for the HTTP request.
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
// - httpClient: The HTTP client to use for the request.
//
// Returns:
// - A pointer to the fetched JWKSet.
// - An error if the request fails, the status code is not OK, or the response body cannot be decoded.
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Create a request with context to enforce timeout
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
@@ -109,6 +137,16 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
return &jwks, nil
}
// jwkToPEM converts a JWK (JSON Web Key) object into PEM (Privacy-Enhanced Mail) format.
// It selects the appropriate conversion function based on the JWK's key type ("kty").
// Currently supports "RSA" and "EC" key types.
//
// Parameters:
// - jwk: A pointer to the JWK object to convert.
//
// Returns:
// - A byte slice containing the public key in PEM format.
// - An error if the key type is unsupported or conversion fails.
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -124,6 +162,17 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK into PEM format.
// It decodes the modulus (n) and exponent (e) from base64 URL encoding,
// constructs an rsa.PublicKey, marshals it into PKIX format, and then
// encodes it as a PEM block.
//
// Parameters:
// - jwk: A pointer to the RSA JWK object (must have "kty": "RSA").
//
// Returns:
// - A byte slice containing the RSA public key in PEM format.
// - An error if decoding parameters fails or key marshaling fails.
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
@@ -155,6 +204,18 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC (Elliptic Curve) JWK into PEM format.
// It decodes the X and Y coordinates from base64 URL encoding, determines the
// elliptic curve based on the "crv" parameter (P-256, P-384, P-521),
// constructs an ecdsa.PublicKey, marshals it into PKIX format, and then
// encodes it as a PEM block.
//
// Parameters:
// - jwk: A pointer to the EC JWK object (must have "kty": "EC").
//
// Returns:
// - A byte slice containing the EC public key in PEM format.
// - An error if decoding parameters fails, the curve is unsupported, or key marshaling fails.
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
+100 -31
View File
@@ -20,6 +20,10 @@ var (
replayCache = make(map[string]time.Time)
)
// cleanupReplayCache iterates through the replay cache and removes entries
// whose expiration time is before the current time. This function should be
// called periodically to prevent the cache from growing indefinitely.
// It acquires a mutex to ensure thread safety during cleanup.
func cleanupReplayCache() {
now := time.Now()
for token, expiry := range replayCache {
@@ -48,7 +52,18 @@ type JWT struct {
Token string
}
// parseJWT parses a JWT token string into a JWT struct.
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
// It splits the token string by '.', decodes each part using base64 URL decoding,
// and unmarshals the header and claims JSON into maps. The raw signature bytes are stored.
// It performs basic format validation (expecting 3 parts).
// Note: This function does *not* validate the signature or the claims.
//
// Parameters:
// - tokenString: The raw JWT string.
//
// Returns:
// - A pointer to a JWT struct containing the decoded parts.
// - An error if the token format is invalid or decoding/unmarshaling fails.
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -84,8 +99,24 @@ func parseJWT(tokenString string) (*JWT, error) {
return jwt, nil
}
// Verify validates the standard JWT claims as defined in RFC 7519.
// Verify validates the standard JWT claims as defined in RFC 7519.
// Verify performs standard claim validation on the JWT according to RFC 7519.
// It checks the following:
// - Algorithm ('alg') is supported.
// - Issuer ('iss') matches the expected issuerURL.
// - Audience ('aud') contains the expected clientID.
// - Expiration time ('exp') is in the future (within tolerance).
// - Issued at time ('iat') is in the past (within tolerance).
// - Not before time ('nbf'), if present, is in the past (within tolerance).
// - Subject ('sub') claim exists and is not empty.
// - JWT ID ('jti'), if present, is checked against a replay cache to prevent token reuse.
//
// Parameters:
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
// - clientID: The expected audience value (the client ID of this application).
//
// Returns:
// - nil if all standard claims are valid.
// - An error describing the first validation failure encountered.
func (j *JWT) Verify(issuerURL, clientID string) error {
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
@@ -175,6 +206,16 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience checks if the expected audience is present in the token's 'aud' claim.
// The 'aud' claim can be a single string or an array of strings.
//
// Parameters:
// - tokenAudience: The 'aud' claim value extracted from the token (can be string or []interface{}).
// - expectedAudience: The audience value expected for this application (client ID).
//
// Returns:
// - nil if the expected audience is found.
// - An error if the claim type is invalid or the expected audience is not present.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
@@ -198,6 +239,15 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
return nil
}
// verifyIssuer checks if the token's 'iss' claim matches the expected issuer URL.
//
// Parameters:
// - tokenIssuer: The 'iss' claim value from the token.
// - expectedIssuer: The expected issuer URL configured for the OIDC provider.
//
// Returns:
// - nil if the issuers match.
// - An error if the issuers do not match.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
@@ -205,57 +255,76 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
return nil
}
// verifyTimeConstraint is a generic function to verify time-based claims
// verifyTimeConstraint checks time-based claims ('exp', 'iat', 'nbf') against the current time,
// allowing for configurable clock skew. It uses different tolerances for past and future checks.
//
// Parameters:
// - unixTime: The timestamp value from the claim (as a float64 Unix time).
// - claimName: The name of the claim being verified ("exp", "iat", "nbf").
// - future: A boolean indicating the direction of the check (true for 'exp', false for 'iat'/'nbf').
//
// Returns:
// - nil if the time constraint is met within the allowed tolerance.
// - An error describing the failure (e.g., "token has expired", "token used before issued").
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
claimTime := time.Unix(int64(unixTime), 0)
now := time.Now().Truncate(time.Second)
now := time.Now() // Use current time without truncation
var skewedNow time.Time
if future {
// For expiration (future=true), add skew to now (making now later)
// Use the larger tolerance for future checks (exp)
skewedNow = now.Add(ClockSkewToleranceFuture)
} else {
// For iat/nbf (future=false), subtract skew from now (making now earlier)
// Use the smaller, specific tolerance for past checks (iat, nbf)
skewedNow = now.Add(-ClockSkewTolerancePast) // Subtract the past tolerance
}
if claimTime.Equal(now) {
return nil
}
// For expiration: if skewedNow (later) is after expiration, token expired
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
var reason string
if future {
reason = "has expired"
} else {
var err error
if future { // 'exp' check
// Token is expired if Now is after (ClaimTime + FutureTolerance)
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
if now.After(allowedExpiry) {
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
}
} else { // 'iat' or 'nbf' check
// Token is invalid if Now is before (ClaimTime - PastTolerance)
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
if now.Before(allowedStart) {
reason := "not yet valid"
if claimName == "iat" {
reason = "used before issued"
} else {
reason = "not yet valid"
}
err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
}
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
}
return nil
return err
}
// verifyExpiration checks the 'exp' (Expiration Time) claim.
// It calls verifyTimeConstraint with future=true.
func verifyExpiration(expiration float64) error {
return verifyTimeConstraint(expiration, "exp", true)
}
// verifyIssuedAt checks the 'iat' (Issued At) claim.
// It calls verifyTimeConstraint with future=false.
func verifyIssuedAt(issuedAt float64) error {
return verifyTimeConstraint(issuedAt, "iat", false)
}
// verifyNotBefore checks the 'nbf' (Not Before) claim.
// It calls verifyTimeConstraint with future=false.
func verifyNotBefore(notBefore float64) error {
return verifyTimeConstraint(notBefore, "nbf", false)
}
// verifySignature validates the JWT's signature using the provided public key.
// It parses the public key from PEM format, selects the appropriate hashing algorithm
// based on the 'alg' parameter (SHA256/384/512), hashes the token's signing input
// (header + "." + payload), and then verifies the signature against the hash using
// the corresponding RSA (PKCS1v15 or PSS) or ECDSA verification method.
//
// Parameters:
// - tokenString: The raw, complete JWT string.
// - publicKeyPEM: The public key corresponding to the private key used for signing, in PEM format.
// - alg: The algorithm specified in the JWT header (e.g., "RS256", "ES384").
//
// Returns:
// - nil if the signature is valid.
// - An error if the token format is invalid, decoding fails, key parsing fails,
// the algorithm is unsupported, or the signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
+326 -91
View File
@@ -17,7 +17,13 @@ import (
"golang.org/x/time/rate"
)
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
// createDefaultHTTPClient creates a new http.Client with settings optimized for OIDC communication.
// It configures the transport with specific timeouts (dial, keepalive, TLS handshake, idle connection),
// connection limits (max idle, max per host), enables HTTP/2, and sets a default request timeout.
// It also configures redirect handling to follow redirects up to a limit.
//
// Returns:
// - A pointer to the configured http.Client.
func createDefaultHTTPClient() *http.Client {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -128,16 +134,20 @@ var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
// VerifyToken implements the TokenVerifier interface to verify an OIDC token.
// It performs a complete verification process including:
// 1. Checking the token cache to avoid redundant verifications
// 2. Performing rate limiting and blacklist checks
// 3. Parsing the JWT structure
// 4. Verifying the JWT signature against the JWKS from the provider
// 5. Validating standard JWT claims (iss, aud, exp, etc.)
// 6. Caching the verified token for future requests
// VerifyToken implements the TokenVerifier interface. It performs a comprehensive validation of an ID token:
// 1. Checks the token cache; returns nil immediately if a valid cached entry exists.
// 2. Performs pre-verification checks (rate limiting, blacklist).
// 3. Parses the raw token string into a JWT struct.
// 4. Verifies the JWT signature and standard claims (iss, aud, exp, iat, nbf, sub) using VerifyJWTSignatureAndClaims.
// 5. If verification succeeds, caches the token claims until the token's expiration time.
// 6. If verification succeeds and the token has a JTI claim, adds the JTI to the blacklist cache to prevent replay attacks.
//
// Returns nil if the token is valid, or an error describing the validation failure.
// Parameters:
// - token: The raw ID token string to verify.
//
// Returns:
// - nil if the token is valid according to all checks.
// - An error describing the reason for validation failure (e.g., rate limit, blacklisted, parsing error, signature error, claim error).
func (t *TraefikOidc) VerifyToken(token string) error {
// Check cache first
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
@@ -192,7 +202,16 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return nil
}
// performPreVerificationChecks performs rate limiting and blacklist checks
// performPreVerificationChecks executes preliminary checks before attempting full token validation.
// It enforces rate limiting using the configured limiter and checks if the raw token string
// or its JTI (if extractable) exists in the blacklist cache.
//
// Parameters:
// - token: The raw token string being verified.
//
// Returns:
// - nil if all pre-verification checks pass.
// - An error if the rate limit is exceeded or the token/JTI is blacklisted.
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
// Enforce rate limiting
if !t.limiter.Allow() {
@@ -218,7 +237,13 @@ func (t *TraefikOidc) performPreVerificationChecks(token string) error {
return nil
}
// cacheVerifiedToken caches a verified token until its expiration time
// cacheVerifiedToken adds the claims of a successfully verified token to the token cache.
// It calculates the remaining duration until the token's 'exp' claim and uses that
// duration for the cache entry's lifetime.
//
// Parameters:
// - token: The raw token string (used as the cache key).
// - claims: The map of claims extracted from the verified token.
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
now := time.Now()
@@ -226,7 +251,18 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
t.tokenCache.Set(token, claims, duration)
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
// VerifyJWTSignatureAndClaims implements the JWTVerifier interface. It verifies the signature
// of a parsed JWT against the provider's public keys obtained from the JWKS endpoint,
// and then validates the standard JWT claims (iss, aud, exp, iat, nbf, sub, jti replay).
//
// Parameters:
// - jwt: A pointer to the parsed JWT struct containing header and claims.
// - token: The original raw token string (used for signature verification).
//
// Returns:
// - nil if both the signature and all standard claims are valid.
// - An error describing the validation failure (e.g., failed to get JWKS, missing kid/alg,
// no matching key, signature verification failed, standard claim validation failed).
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT signature and claims")
@@ -277,24 +313,28 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
return nil
}
// New creates a new instance of the OIDC middleware.
// This is the main entry point for the middleware and is called by Traefik when loading the plugin.
// It initializes all components needed for OIDC authentication:
// - Session management for storing user state
// - Token caching and blacklisting
// - JWK caching for signature verification
// - Rate limiting to prevent abuse
// - Metadata discovery for OIDC provider endpoints
// New is the constructor for the TraefikOidc middleware plugin.
// It is called by Traefik during plugin initialization. It performs the following steps:
// 1. Creates a default configuration if none is provided.
// 2. Validates the session encryption key length.
// 3. Initializes the logger based on the configured log level.
// 4. Sets up the HTTP client (using defaults if none provided in config).
// 5. Creates the main TraefikOidc struct, populating fields from the config
// (paths, client details, PKCE/HTTPS flags, scopes, rate limiter, caches, allowed lists).
// 6. Initializes the SessionManager.
// 7. Sets up internal function pointers/interfaces (extractClaimsFunc, initiateAuthenticationFunc, tokenVerifier, jwtVerifier, tokenExchanger).
// 8. Adds default excluded URLs.
// 9. Starts background goroutines for token cache cleanup and OIDC provider metadata initialization/refresh.
//
// Parameters:
// - ctx: Context for initialization operations
// - next: The next handler in the middleware chain
// - config: Configuration options for the middleware
// - name: Identifier for this middleware instance
// - ctx: The context provided by Traefik for initialization.
// - next: The next http.Handler in the Traefik middleware chain.
// - config: The plugin configuration provided by the user in Traefik static/dynamic configuration.
// - name: The name assigned to this middleware instance by Traefik.
//
// Returns:
// - An http.Handler that implements the middleware
// - An error if initialization fails
// - An http.Handler (the TraefikOidc instance itself, which implements ServeHTTP).
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
if config == nil {
config = CreateConfig()
@@ -386,7 +426,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
return t, nil
}
// initializeMetadata discovers and initializes the provider metadata
// initializeMetadata asynchronously fetches and caches the OIDC provider metadata.
// It uses the MetadataCache to retrieve potentially cached data or fetch fresh data
// via discoverProviderMetadata. On successful retrieval, it updates the middleware's
// endpoint URLs (auth, token, jwks, etc.), starts the periodic metadata refresh goroutine,
// and signals completion by closing the initComplete channel. If fetching fails initially,
// it logs an error and the middleware might remain uninitialized until a successful refresh.
//
// Parameters:
// - providerURL: The base URL of the OIDC provider.
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
@@ -412,7 +460,12 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Error("Received nil metadata")
}
// updateMetadataEndpoints updates the middleware with metadata endpoints
// updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.)
// within the TraefikOidc instance based on the discovered provider metadata.
// This is called after successfully fetching or refreshing the metadata.
//
// Parameters:
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
@@ -422,7 +475,13 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
t.endSessionURL = metadata.EndSessionURL
}
// startMetadataRefresh periodically refreshes the OIDC metadata
// startMetadataRefresh starts a background goroutine that periodically attempts to refresh
// the OIDC provider metadata by calling GetMetadata on the metadataCache.
// It runs on a fixed ticker (currently 1 hour). Successful refreshes update the
// middleware's endpoint URLs via updateMetadataEndpoints. Fetch errors are logged.
//
// Parameters:
// - providerURL: The base URL of the OIDC provider, used for subsequent refresh attempts.
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
@@ -442,7 +501,19 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
}
}
// discoverProviderMetadata fetches the OIDC provider metadata
// discoverProviderMetadata attempts to fetch the OIDC provider's configuration from its
// well-known discovery endpoint (".well-known/openid-configuration").
// It implements an exponential backoff retry mechanism in case of transient network errors
// or provider unavailability during startup.
//
// Parameters:
// - providerURL: The base URL of the OIDC provider.
// - httpClient: The HTTP client to use for the request.
// - l: The logger instance for recording retries and errors.
//
// Returns:
// - A pointer to the fetched ProviderMetadata struct.
// - An error if fetching fails after all retries or if a timeout is exceeded.
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
@@ -481,7 +552,16 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
}
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
// fetchMetadata performs a single attempt to fetch and decode the OIDC provider metadata
// from the specified well-known configuration URL.
//
// Parameters:
// - wellKnownURL: The full URL to the ".well-known/openid-configuration" endpoint.
// - httpClient: The HTTP client to use for the GET request.
//
// Returns:
// - A pointer to the decoded ProviderMetadata struct.
// - An error if the GET request fails, the status code is not 200 OK, or JSON decoding fails.
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
resp, err := httpClient.Get(wellKnownURL)
if err != nil {
@@ -504,18 +584,22 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware that processes all HTTP requests.
// It implements the http.Handler interface and performs the following operations:
// 1. Waits for OIDC provider metadata initialization to complete
// 2. Checks if the requested URL is in the excluded list (bypassing authentication)
// 3. Retrieves or creates a user session
// 4. Handles special paths like callback and logout URLs
// 5. Verifies authentication status and token validity
// 6. Refreshes tokens that are about to expire
// 7. Validates user email domains, roles, and groups against configured restrictions
// 8. Sets appropriate headers for downstream services
// 9. Applies security headers to responses
// 10. Forwards the authenticated request to the next handler
// ServeHTTP is the main entry point for incoming requests to the middleware.
// It orchestrates the OIDC authentication flow:
// 1. Waits for initial OIDC metadata discovery to complete (with timeout).
// 2. Checks if the request path is excluded from authentication.
// 3. Checks if the request is for Server-Sent Events and bypasses if so.
// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing.
// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout).
// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry).
// 7. If the token is expired, handles it (initiates re-auth).
// 8. If the user is not authenticated, initiates authentication.
// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure
// by returning 401 for API clients or initiating re-auth for browsers.
// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups).
// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*)
// and adds security headers (X-Frame-Options, etc.) to the response.
// 12. Forwards the request to the next handler in the chain.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
@@ -698,8 +782,17 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.next.ServeHTTP(rw, req)
}
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
// handleExpiredToken is called when a user's session contains an expired token or
// when a token refresh attempt fails for a browser client.
// It clears the authentication-related data (tokens, email, authenticated flag) from the session,
// saves the cleared session, and then initiates a new authentication flow by calling
// defaultInitiateAuthentication, redirecting the user to the OIDC provider.
//
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request.
// - session: The user's session data containing the expired token information.
// - redirectURL: The callback URL to be used in the new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
@@ -717,9 +810,27 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
// handleCallback handles the request received at the OIDC callback URL (redirect_uri).
// It performs the following steps:
// 1. Retrieves the user session associated with the callback request.
// 2. Checks for error parameters returned by the OIDC provider.
// 3. Validates the 'state' parameter against the CSRF token stored in the session.
// 4. Extracts the authorization 'code' from the query parameters.
// 5. Retrieves the PKCE 'code_verifier' from the session (if PKCE is enabled).
// 6. Exchanges the authorization code for tokens using the TokenExchanger interface.
// 7. Verifies the received ID token's signature and standard claims using VerifyToken.
// 8. Extracts claims from the verified ID token.
// 9. Verifies the 'nonce' claim against the nonce stored in the session.
// 10. Validates the user's email domain against the allowed list.
// 11. If all checks pass, updates the session with authentication details (status, email, tokens).
// 12. Saves the updated session.
// 13. Redirects the user back to their original requested path (stored in session) or the root path.
// If any step fails, it sends an appropriate error response using sendErrorResponse.
//
// Parameters:
// - rw: The HTTP response writer.
// - req: The incoming HTTP request to the callback URL.
// - redirectURL: The fully qualified callback URL (used in the token exchange request).
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
@@ -846,7 +957,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// determineExcludedURL checks if the current request URL is in the excluded list
// determineExcludedURL checks if the provided request path matches any of the configured excluded URL prefixes.
//
// Parameters:
// - currentRequest: The path part of the incoming request URL.
//
// Returns:
// - true if the path starts with any of the prefixes in the t.excludedURLs map.
// - false otherwise.
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
@@ -858,7 +976,15 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
return false
}
// determineScheme determines the scheme (http or https) of the request
// determineScheme determines the request scheme (http or https).
// It prioritizes the X-Forwarded-Proto header if present, otherwise checks
// the TLS property of the request. Defaults to "http".
//
// Parameters:
// - req: The incoming HTTP request.
//
// Returns:
// - "https" or "http".
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
@@ -869,7 +995,14 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
return "http"
}
// determineHost determines the host of the request
// determineHost determines the request host.
// It prioritizes the X-Forwarded-Host header if present, otherwise uses the req.Host value.
//
// Parameters:
// - req: The incoming HTTP request.
//
// Returns:
// - The determined host string (e.g., "example.com:8080").
func (t *TraefikOidc) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
@@ -877,17 +1010,18 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated by validating their session and token.
// It performs a comprehensive check of the authentication state including:
// 1. Verifying the session's authenticated flag
// 2. Checking for the presence of an access token
// 3. Validating the token's signature and claims
// 4. Checking the token's expiration time
// isUserAuthenticated checks the authentication status based on the provided session data.
// It verifies the session's authenticated flag, the presence and validity of the access token (ID token),
// including signature and standard claims (using VerifyJWTSignatureAndClaims). It also checks if the
// token is within the configured refreshGracePeriod before its actual expiration.
//
// Returns three boolean values:
// - authenticated: Whether the user is currently authenticated
// - needsRefresh: Whether the token is valid but will expire soon (within grace period)
// - expired: Whether the token has expired or is otherwise invalid
// Parameters:
// - session: The SessionData object for the current user.
//
// Returns:
// - authenticated (bool): True if the session is marked authenticated and the token is present and valid (signature/claims ok, not expired beyond grace).
// - needsRefresh (bool): True if the token is valid but nearing expiration (within refreshGracePeriod) OR if VerifyJWTSignatureAndClaims failed specifically due to expiration (meaning refresh might be possible).
// - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims).
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
@@ -945,19 +1079,18 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, false, false
}
// defaultInitiateAuthentication initiates the OIDC authentication process.
// This function prepares and starts a new authentication flow by:
// 1. Generating security tokens (CSRF token and nonce) to prevent attacks
// 2. Clearing any existing session data to avoid state conflicts
// 3. Storing the original request path to redirect back after authentication
// 4. Building the authorization URL with all required OIDC parameters
// 5. Redirecting the user to the OIDC provider's authorization endpoint
// defaultInitiateAuthentication handles the process of starting an OIDC authentication flow.
// It generates necessary security values (CSRF token, nonce, PKCE verifier/challenge if enabled),
// clears any potentially stale data from the current session, stores the new security values
// and the original request URI in the session, saves the session (setting cookies),
// builds the OIDC authorization endpoint URL with required parameters, and finally
// redirects the user's browser to that URL.
//
// Parameters:
// - rw: The HTTP response writer for sending the redirect
// - req: The original HTTP request that triggered authentication
// - session: The user's session data for storing authentication state
// - redirectURL: The callback URL where the OIDC provider will redirect after authentication
// - rw: The HTTP response writer used to send the redirect response.
// - req: The original incoming HTTP request that requires authentication.
// - session: The user's SessionData object (potentially new or cleared).
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.NewString()
@@ -1007,15 +1140,33 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier interface.
// This function delegates to the configured token verifier implementation,
// which by default is the TraefikOidc instance itself (implementing the VerifyToken method).
// This design allows for easy mocking in tests and potential future extension.
// verifyToken is a wrapper method that calls the VerifyToken method of the configured
// TokenVerifier interface (which defaults to the TraefikOidc instance itself).
// This primarily exists to facilitate testing and potential future extensions where
// token verification logic might be delegated differently.
//
// Parameters:
// - token: The raw token string to verify.
//
// Returns:
// - The result of calling t.tokenVerifier.VerifyToken(token).
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL with optional PKCE support
// buildAuthURL constructs the OIDC authorization endpoint URL with all necessary query parameters
// for initiating the authorization code flow. It includes client_id, response_type, redirect_uri,
// state, nonce, and optionally PKCE parameters (code_challenge, code_challenge_method) if enabled
// and a challenge is provided. It also includes configured scopes.
//
// Parameters:
// - redirectURL: The callback URL (redirect_uri).
// - state: The CSRF token.
// - nonce: The OIDC nonce.
// - codeChallenge: The PKCE code challenge (can be empty if PKCE is disabled or not used).
//
// Returns:
// - The fully constructed authorization URL string.
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
@@ -1037,7 +1188,16 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
return t.buildURLWithParams(t.authURL, params)
}
// buildURLWithParams ensures a URL is absolute and appends query parameters
// buildURLWithParams takes a base URL and query parameters and constructs a full URL string.
// If the baseURL is relative (doesn't start with http/https), it prepends the scheme and host
// from the configured issuerURL. It then appends the encoded query parameters.
//
// Parameters:
// - baseURL: The base URL (can be absolute or relative to the issuer).
// - params: A url.Values map containing the query parameters to append.
//
// Returns:
// - The fully constructed URL string with appended query parameters.
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
// Ensure URL is absolute
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
@@ -1054,7 +1214,8 @@ func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) stri
return baseURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
// startTokenCleanup starts background goroutines for periodically cleaning up
// the token cache, token blacklist cache, and JWK cache using the autoCleanupRoutine helper.
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
go func() {
@@ -1069,7 +1230,14 @@ func (t *TraefikOidc) startTokenCleanup() {
}()
}
// RevokeToken adds the token to the blacklist
// RevokeToken handles local revocation of a token.
// It removes the token from the validation cache (tokenCache) and adds the raw
// token string to the blacklist cache (tokenBlacklist) with a default expiration (24h).
// This prevents the token from being validated successfully even if it hasn't expired yet.
// Note: This does *not* revoke the token with the OIDC provider.
//
// Parameters:
// - token: The raw token string to revoke locally.
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
@@ -1080,7 +1248,17 @@ func (t *TraefikOidc) RevokeToken(token string) {
t.tokenBlacklist.Set(token, true, time.Until(expiry))
}
// RevokeTokenWithProvider revokes the token with the provider
// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider
// using the revocation endpoint specified in the provider metadata or configuration.
// It sends a POST request with the token, token_type_hint, client_id, and client_secret.
//
// Parameters:
// - token: The token (e.g., refresh token or access token) to revoke.
// - tokenType: The type hint for the token being revoked (e.g., "refresh_token").
//
// Returns:
// - nil if the revocation request is successful (provider returns 200 OK).
// - An error if the request fails or the provider returns a non-OK status.
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
t.logger.Debugf("Revoking token with provider")
@@ -1117,7 +1295,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
return nil
}
// refreshToken refreshes the user's token, protected by a mutex within the session.
// refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens.
// It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session.
// It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method,
// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens,
// and saves the session.
//
// Parameters:
// - rw: The HTTP response writer (needed for saving the updated session).
// - req: The HTTP request (needed for saving the updated session).
// - session: The user's SessionData object containing the refresh token.
//
// Returns:
// - true if the token refresh was successful and the session was updated.
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed.
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
// Lock the mutex specific to this session instance before attempting refresh
session.refreshMutex.Lock()
@@ -1159,7 +1350,16 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// isAllowedDomain checks if the user's email domain is allowed
// isAllowedDomain checks if the domain part of the provided email address is present
// in the configured list of allowed domains (t.allowedUserDomains).
// If the allowed domains list is empty, all domains are considered allowed.
//
// Parameters:
// - email: The email address to check.
//
// Returns:
// - true if the domain is allowed or if no domain restrictions are configured.
// - false if the email format is invalid or the domain is not in the allowed list.
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
@@ -1175,7 +1375,18 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
// extractGroupsAndRoles attempts to extract 'groups' and 'roles' claims from a decoded ID token.
// It expects these claims, if present, to be arrays of strings.
// It uses the configured extractClaimsFunc (which defaults to the package-level extractClaims)
// to get the claims map from the token string.
//
// Parameters:
// - idToken: The raw ID token string.
//
// Returns:
// - A slice of strings containing the groups found in the 'groups' claim.
// - A slice of strings containing the roles found in the 'roles' claim.
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present but not arrays of strings.
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
@@ -1216,7 +1427,17 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
return groups, roles, nil
}
// buildFullURL constructs a full URL from scheme, host and path
// buildFullURL constructs an absolute URL string from its components.
// If the provided path already starts with "http://" or "https://", it's returned directly.
// Otherwise, it combines the scheme, host, and path, ensuring the path starts with a '/'.
//
// Parameters:
// - scheme: The URL scheme ("http" or "https").
// - host: The host part of the URL (e.g., "example.com:8080").
// - path: The path part of the URL (e.g., "/resource").
//
// Returns:
// - The combined absolute URL string (e.g., "https://example.com:8080/resource").
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
@@ -1233,21 +1454,35 @@ func buildFullURL(scheme, host, path string) string {
// --- TokenExchanger Interface Implementation ---
// ExchangeCodeForToken implements the TokenExchanger interface.
// It calls the existing exchangeTokens helper function.
// ExchangeCodeForToken provides the implementation for the TokenExchanger interface method.
// It directly calls the internal exchangeTokens method, passing through the arguments.
// This allows the TraefikOidc struct to act as its own default TokenExchanger, while
// still allowing mocking for tests.
func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Note: The original exchangeTokens helper is defined in helpers.go and is already a method on *TraefikOidc
return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
// GetNewTokenWithRefreshToken implements the TokenExchanger interface.
// It calls the existing getNewTokenWithRefreshToken helper function.
// GetNewTokenWithRefreshToken provides the implementation for the TokenExchanger interface method.
// It directly calls the internal getNewTokenWithRefreshToken helper method.
// This allows the TraefikOidc struct to act as its own default TokenExchanger, while
// still allowing mocking for tests.
func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
// Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc
return t.getNewTokenWithRefreshToken(refreshToken)
}
// sendErrorResponse sends an error response, adapting to the client's Accept header.
// sendErrorResponse sends an error response to the client, adapting the format based
// on the request's Accept header. If the client prefers "application/json", it sends
// a JSON object with "error", "error_description", and "status_code" fields.
// Otherwise, it sends a basic HTML error page containing the message and a link
// back to the application root or the original incoming path (if available from the session).
//
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request (used to check Accept header and potentially get session).
// - message: The error message to display/include in the response.
// - code: The HTTP status code to set for the response.
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
acceptHeader := req.Header.Get("Accept")
+301 -5
View File
@@ -376,6 +376,7 @@ func TestServeHTTP(t *testing.T) {
setupSession func(*SessionData)
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks
requestHeaders map[string]string // Added for setting headers like Accept
}{
{
name: "Excluded URL",
@@ -394,7 +395,13 @@ func TestServeHTTP(t *testing.T) {
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token) // Use the valid token generated in Setup
// Generate a fresh valid token for this test case to avoid replay issues
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetRefreshToken("valid-refresh-token")
},
expectedStatus: http.StatusOK,
@@ -450,14 +457,161 @@ func TestServeHTTP(t *testing.T) {
},
{
name: "Logout URL",
requestPath: "/logout", // Assuming logout path is configured or defaulted correctly
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
},
expectedStatus: http.StatusFound, // Expect redirect after logout
expectedBody: "",
// No specific session assertion needed for logout redirect itself
},
{
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken())
session.SetRefreshToken("valid-refresh-token")
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
}
},
requestHeaders: map[string]string{
"Accept": "application/json",
},
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
},
{
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken())
session.SetRefreshToken("valid-refresh-token")
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
}
},
requestHeaders: map[string]string{
"Accept": "text/html", // Browser client
},
expectedStatus: http.StatusFound, // Expect redirect for browser client
},
{
name: "Authenticated request with token nearing expiry (needs refresh)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Create token expiring soon (e.g., 30s, within default 60s grace period)
exp := time.Now().Add(30 * time.Second).Unix()
iat := time.Now().Add(-1 * time.Minute).Unix()
nbf := time.Now().Add(-1 * time.Minute).Unix()
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry")
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token-for-near-expiry" {
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
}
// Simulate successful refresh
newToken := createNewValidToken()
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil
}
},
expectedStatus: http.StatusOK, // Expect success after proactive refresh
expectedBody: "OK",
},
{
name: "Authenticated request with token valid (outside grace period)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Create token expiring later (e.g., 10 mins, outside default 60s grace period)
exp := time.Now().Add(10 * time.Minute).Unix()
iat := time.Now().Add(-1 * time.Minute).Unix()
nbf := time.Now().Add(-1 * time.Minute).Unix()
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(validToken)
session.SetRefreshToken("should-not-be-used-refresh-token")
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
// This should NOT be called
return func(refreshToken string) (*TokenResponse, error) {
t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period")
return nil, fmt.Errorf("refresh should not have been attempted")
}
},
expectedStatus: http.StatusOK, // Expect success, no refresh needed
expectedBody: "OK",
},
{
name: "Disallowed Domain (Accept: JSON)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetRefreshToken("valid-refresh-token")
},
requestHeaders: map[string]string{
"Accept": "application/json",
},
expectedStatus: http.StatusForbidden,
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
},
{
name: "Disallowed Domain (Accept: HTML)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetRefreshToken("valid-refresh-token")
},
requestHeaders: map[string]string{
"Accept": "text/html",
},
expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response
expectedBody: "", // Body check is harder for HTML, focus on status and content-type
},
}
@@ -468,6 +622,12 @@ func TestServeHTTP(t *testing.T) {
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
req.Header.Set("X-Forwarded-Host", "testhost.com")
req.Host = "testhost.com" // Also set Host header
// Set request headers from test case
if tc.requestHeaders != nil {
for key, value := range tc.requestHeaders {
req.Header.Set(key, value)
}
}
rr := httptest.NewRecorder()
@@ -527,10 +687,19 @@ func TestServeHTTP(t *testing.T) {
}
// Check response body if expected
// Check response body if expected (handle JSON vs HTML)
if tc.expectedBody != "" {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
// For JSON, compare directly
if strings.Contains(rr.Header().Get("Content-Type"), "application/json") {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body)
}
} else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
}
}
// Add more sophisticated HTML body checks if needed
}
// Perform post-request session assertions if defined
@@ -2319,3 +2488,130 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
}
}
// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances.
func TestVerifyTimeConstraint(t *testing.T) {
// Define tolerances used in jwt.go (ensure they match)
toleranceFuture := 2 * time.Minute
tolerancePast := 10 * time.Second
now := time.Now()
tests := []struct {
name string
claimTime time.Time
claimName string
futureCheck bool // true for exp, false for iat/nbf
expectError bool
}{
// Expiration (future=true, tolerance=2min)
{
name: "EXP: Valid (expires in 1 min)",
claimTime: now.Add(1 * time.Minute),
claimName: "exp",
futureCheck: true,
expectError: false,
},
{
name: "EXP: Expired (expired 3 min ago)",
claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance
claimName: "exp",
futureCheck: true,
expectError: true,
},
{
name: "EXP: Valid (expired 1 min ago, within 2min tolerance)",
claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance
claimName: "exp",
futureCheck: true,
expectError: false, // Should be allowed due to future tolerance
},
// Issued At (future=false, tolerance=10s)
{
name: "IAT: Valid (issued 1 min ago)",
claimTime: now.Add(-1 * time.Minute),
claimName: "iat",
futureCheck: false,
expectError: false,
},
{
name: "IAT: Invalid (issued 15 sec in future)",
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
claimName: "iat",
futureCheck: false,
expectError: true, // "token used before issued"
},
{
name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)",
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
claimName: "iat",
futureCheck: false,
expectError: false, // Should be allowed due to past tolerance
},
// Not Before (future=false, tolerance=10s)
{
name: "NBF: Valid (active 1 min ago)",
claimTime: now.Add(-1 * time.Minute),
claimName: "nbf",
futureCheck: false,
expectError: false,
},
{
name: "NBF: Invalid (active in 15 sec)",
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
claimName: "nbf",
futureCheck: false,
expectError: true, // "token not yet valid"
},
{
name: "NBF: Valid (active in 5 sec, within 10s tolerance)",
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
claimName: "nbf",
futureCheck: false,
expectError: false, // Should be allowed due to past tolerance
},
}
// Temporarily adjust global tolerances for test consistency, then restore
originalFutureTolerance := ClockSkewToleranceFuture
originalPastTolerance := ClockSkewTolerancePast
ClockSkewToleranceFuture = toleranceFuture
ClockSkewTolerancePast = tolerancePast
defer func() {
ClockSkewToleranceFuture = originalFutureTolerance
ClockSkewTolerancePast = originalPastTolerance
}()
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Convert claim time to float64 unix timestamp
unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9
var err error
// Call the specific verification function which uses verifyTimeConstraint
if tc.claimName == "exp" {
err = verifyExpiration(unixTime)
} else if tc.claimName == "iat" {
err = verifyIssuedAt(unixTime)
} else if tc.claimName == "nbf" {
err = verifyNotBefore(unixTime)
} else {
t.Fatalf("Unknown claim name in test setup: %s", tc.claimName)
}
if tc.expectError {
if err == nil {
t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now)
} else {
t.Logf("Got expected error: %v", err) // Log the error for confirmation
}
} else {
if err != nil {
t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err)
}
}
})
}
} // Add missing closing brace for TestVerifyTimeConstraint
+28 -2
View File
@@ -15,6 +15,8 @@ type MetadataCache struct {
stopCleanup chan struct{}
}
// NewMetadataCache creates a new MetadataCache instance.
// It initializes the cache structure and starts the background cleanup goroutine.
func NewMetadataCache() *MetadataCache {
c := &MetadataCache{
autoCleanupInterval: 5 * time.Minute,
@@ -24,7 +26,8 @@ func NewMetadataCache() *MetadataCache {
return c
}
// Cleanup removes expired metadata from the cache.
// Cleanup removes the cached provider metadata if it has expired.
// This is called periodically by the auto-cleanup goroutine.
func (c *MetadataCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -35,11 +38,31 @@ func (c *MetadataCache) Cleanup() {
}
}
// isCacheValid checks if the cached metadata is present and has not expired.
// Note: This function assumes the read lock is held or it's called from a context
// where the lock is already held (like within GetMetadata after locking).
func (c *MetadataCache) isCacheValid() bool {
return c.metadata != nil && time.Now().Before(c.expiresAt)
}
// GetMetadata retrieves the metadata from cache or fetches it if expired
// GetMetadata retrieves the OIDC provider metadata.
// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately.
// If the cache is empty or expired, it attempts to fetch the metadata from the provider's
// well-known endpoint using discoverProviderMetadata.
// If fetching is successful, the new metadata is cached for 1 hour.
// If fetching fails but valid metadata exists in the cache (even if expired), the cache expiry
// is extended by 5 minutes, and the cached data is returned to prevent thundering herd issues.
// If fetching fails and there's no cached data, an error is returned.
// It employs double-checked locking for thread safety and performance.
//
// Parameters:
// - providerURL: The base URL of the OIDC provider.
// - httpClient: The HTTP client to use for fetching metadata.
// - logger: The logger instance for recording errors or warnings.
//
// Returns:
// - A pointer to the ProviderMetadata struct.
// - An error if metadata cannot be retrieved from cache or fetched from the provider.
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
c.mutex.RLock()
if c.isCacheValid() {
@@ -76,10 +99,13 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
return metadata, nil
}
// startAutoCleanup starts the background goroutine that periodically calls Cleanup
// to remove expired metadata from the cache.
func (c *MetadataCache) startAutoCleanup() {
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
}
// Close stops the automatic cleanup goroutine associated with this metadata cache.
func (c *MetadataCache) Close() {
close(c.stopCleanup)
}
+186 -42
View File
@@ -16,8 +16,15 @@ import (
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length.
// It returns the generated string or an error if random generation fails.
// generateSecureRandomString creates a cryptographically secure, hex-encoded random string.
// It reads the specified number of bytes from crypto/rand and encodes them as a hexadecimal string.
//
// Parameters:
// - length: The number of random bytes to generate (the resulting hex string will be twice this length).
//
// Returns:
// - A hex-encoded random string.
// - An error if reading random bytes fails.
func generateSecureRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
@@ -56,7 +63,14 @@ const (
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it.
// compressToken compresses the input string using gzip and then encodes the result using standard base64 encoding.
// If any error occurs during compression, it returns the original uncompressed token as a fallback.
//
// Parameters:
// - token: The string to compress.
//
// Returns:
// - The base64 encoded, gzipped string, or the original string if compression fails.
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
@@ -69,7 +83,15 @@ func compressToken(token string) string {
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token.
// decompressToken decodes a standard base64 encoded string and then decompresses the result using gzip.
// If base64 decoding or gzip decompression fails, it returns the original input string as a fallback,
// assuming it might not have been compressed.
//
// Parameters:
// - compressed: The base64 encoded, gzipped string.
//
// Returns:
// - The decompressed original string, or the input string if decompression fails.
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
@@ -140,15 +162,15 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
return sm, nil
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS.
// getSessionOptions returns a sessions.Options struct configured with security best practices.
// It sets HttpOnly to true, Secure based on the request scheme or forceHTTPS setting,
// SameSite to LaxMode, MaxAge to the absoluteSessionTimeout, and Path to "/".
//
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
// - Using SameSite=Lax for CSRF protection
// - Set with appropriate timeout and path settings
// Parameters:
// - isSecure: A boolean indicating if the current request context is secure (HTTPS).
//
// Returns:
// - A pointer to a configured sessions.Options struct.
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
@@ -210,11 +232,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
return sessionData, nil
}
// getTokenChunkSessions retrieves all session chunks for a given token type.
// getTokenChunkSessions retrieves all cookie chunks associated with a large token (access or refresh).
// It iteratively attempts to load cookies named "{baseName}_0", "{baseName}_1", etc., until
// a cookie is not found or returns an error. The loaded sessions are stored in the provided chunks map.
//
// Parameters:
// - r: The HTTP request
// - baseName: The base name for the token's session cookies
// - chunks: Map to store the chunks in
// - r: The incoming HTTP request containing the cookies.
// - baseName: The base name of the cookie (e.g., accessTokenCookie).
// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks) to populate with the found session chunks.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
@@ -258,10 +283,16 @@ type SessionData struct {
refreshMutex sync.Mutex
}
// Save persists all session data to cookies in the HTTP response.
// It saves the main session, token sessions, and any token chunks,
// applying appropriate security options to each cookie. All cookies
// are saved with consistent security settings based on the request scheme.
// Save persists all parts of the session (main, access token, refresh token, and any chunks)
// back to the client as cookies in the HTTP response. It applies secure cookie options
// obtained via getSessionOptions based on the request's security context.
//
// Parameters:
// - r: The original HTTP request (used to determine security context for cookie options).
// - w: The HTTP response writer to which the Set-Cookie headers will be added.
//
// Returns:
// - An error if saving any of the session components fails.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
@@ -305,7 +336,19 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
return nil
}
// Clear removes all session data by expiring all cookies and clearing their values.
// Clear removes all session data associated with this SessionData instance.
// It clears the values map of the main, access, and refresh sessions, sets their MaxAge to -1
// to expire the cookies immediately, and clears any associated token chunk cookies.
// If a ResponseWriter is provided, it attempts to save the expired sessions to send the
// expiring Set-Cookie headers. Finally, it clears internal fields and returns the SessionData
// object to the pool.
//
// Parameters:
// - r: The HTTP request (required by the underlying session store).
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
//
// Returns:
// - An error if saving the expired sessions fails (only if w is not nil).
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions.
sd.mainSession.Options.MaxAge = -1
@@ -340,7 +383,12 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// clearTokenChunks iterates through a map of session chunks, clears their values,
// and sets their MaxAge to -1 to expire them. This is used internally by Clear.
//
// Parameters:
// - r: The HTTP request (required by the underlying session store, though not directly used here).
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
@@ -350,7 +398,12 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
}
}
// GetAuthenticated returns whether the current session is authenticated.
// GetAuthenticated checks if the session is marked as authenticated and has not exceeded
// the absolute session timeout.
//
// Returns:
// - true if the "authenticated" flag is set to true and the session creation time is within the allowed timeout.
// - false otherwise.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
if !auth {
@@ -365,8 +418,15 @@ func (sd *SessionData) GetAuthenticated() bool {
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated updates the session's authentication status and rotates session ID.
// Returns an error if generating a new session ID fails.
// SetAuthenticated sets the authentication status of the session.
// If setting to true, it generates a new secure session ID for the main session
// to prevent session fixation attacks and records the current time as the creation time.
//
// Parameters:
// - value: The boolean authentication status (true for authenticated, false otherwise).
//
// Returns:
// - An error if generating a new session ID fails when setting value to true.
func (sd *SessionData) SetAuthenticated(value bool) error {
if value {
id, err := generateSecureRandomString(32)
@@ -380,7 +440,12 @@ func (sd *SessionData) SetAuthenticated(value bool) error {
return nil
}
// GetAccessToken retrieves the complete access token from the session.
// GetAccessToken retrieves the access token stored in the session.
// It handles reassembling the token from multiple cookie chunks if necessary
// and decompresses it if it was stored compressed.
//
// Returns:
// - The complete, decompressed access token string, or an empty string if not found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
@@ -414,7 +479,14 @@ func (sd *SessionData) GetAccessToken() string {
return token
}
// SetAccessToken stores the access token in the session.
// SetAccessToken stores the provided access token in the session.
// It first expires any existing access token chunk cookies.
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
// it's stored directly in the primary access token session. Otherwise, the compressed token
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_a_0, _oidc_raczylo_a_1, etc.).
//
// Parameters:
// - token: The access token string to store.
func (sd *SessionData) SetAccessToken(token string) {
// Expire any existing chunk cookies first.
if sd.request != nil {
@@ -444,7 +516,12 @@ func (sd *SessionData) SetAccessToken(token string) {
}
}
// GetRefreshToken retrieves the complete refresh token from the session.
// GetRefreshToken retrieves the refresh token stored in the session.
// It handles reassembling the token from multiple cookie chunks if necessary
// and decompresses it if it was stored compressed.
//
// Returns:
// - The complete, decompressed refresh token string, or an empty string if not found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
@@ -478,7 +555,14 @@ func (sd *SessionData) GetRefreshToken() string {
return token
}
// SetRefreshToken stores the refresh token in the session.
// SetRefreshToken stores the provided refresh token in the session.
// It first expires any existing refresh token chunk cookies.
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
// it's stored directly in the primary refresh token session. Otherwise, the compressed token
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_r_0, _oidc_raczylo_r_1, etc.).
//
// Parameters:
// - token: The refresh token string to store.
func (sd *SessionData) SetRefreshToken(token string) {
// Expire any existing chunk cookies first.
if sd.request != nil {
@@ -508,7 +592,13 @@ func (sd *SessionData) SetRefreshToken(token string) {
}
}
// expireAccessTokenChunks expires any existing access token chunk cookies.
// expireAccessTokenChunks finds all existing access token chunk cookies (_oidc_raczylo_a_N)
// associated with the current request, clears their values, and sets their MaxAge to -1.
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
// the expiring Set-Cookie headers. This is used internally when setting a new access token.
//
// Parameters:
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
@@ -526,7 +616,13 @@ func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
}
}
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
// expireRefreshTokenChunks finds all existing refresh token chunk cookies (_oidc_raczylo_r_N)
// associated with the current request, clears their values, and sets their MaxAge to -1.
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
// the expiring Set-Cookie headers. This is used internally when setting a new refresh token.
//
// Parameters:
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
@@ -544,7 +640,15 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
}
}
// splitIntoChunks splits a string into chunks of specified size.
// splitIntoChunks divides a string `s` into a slice of strings, where each element
// has a maximum length of `chunkSize`.
//
// Parameters:
// - s: The string to split.
// - chunkSize: The maximum size of each chunk.
//
// Returns:
// - A slice of strings representing the chunks.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
@@ -559,57 +663,97 @@ func splitIntoChunks(s string, chunkSize int) []string {
return chunks
}
// GetCSRF retrieves the CSRF token from the session.
// GetCSRF retrieves the Cross-Site Request Forgery (CSRF) token stored in the main session.
//
// Returns:
// - The CSRF token string, or an empty string if not set.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// SetCSRF stores the provided CSRF token string in the main session.
// This token is typically generated at the start of the authentication flow.
//
// Parameters:
// - token: The CSRF token to store.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// GetNonce retrieves the OIDC nonce value stored in the main session.
// The nonce is used to associate an ID token with the specific authentication request.
//
// Returns:
// - The nonce string, or an empty string if not set.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// SetNonce stores the provided OIDC nonce string in the main session.
// This nonce is typically generated at the start of the authentication flow.
//
// Parameters:
// - nonce: The nonce string to store.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetCodeVerifier retrieves the PKCE code verifier from the session.
// GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier
// stored in the main session. This is only relevant if PKCE is enabled.
//
// Returns:
// - The code verifier string, or an empty string if not set or PKCE is disabled.
func (sd *SessionData) GetCodeVerifier() string {
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
return codeVerifier
}
// SetCodeVerifier stores the PKCE code verifier in the session.
// SetCodeVerifier stores the provided PKCE code verifier string in the main session.
// This is typically called at the start of the authentication flow if PKCE is enabled.
//
// Parameters:
// - codeVerifier: The PKCE code verifier string to store.
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
sd.mainSession.Values["code_verifier"] = codeVerifier
}
// GetEmail retrieves the authenticated user's email address from the session.
// GetEmail retrieves the authenticated user's email address stored in the main session.
// This is typically extracted from the ID token claims after successful authentication.
//
// Returns:
// - The user's email address string, or an empty string if not set.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// SetEmail stores the provided user email address string in the main session.
// This is typically called after successful authentication and claim extraction.
//
// Parameters:
// - email: The user's email address to store.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
// GetIncomingPath retrieves the original request URI (including query parameters)
// that the user was trying to access before being redirected for authentication.
// This is stored in the main session to allow redirection back after successful login.
//
// Returns:
// - The original request URI string, or an empty string if not set.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered the authentication flow.
// SetIncomingPath stores the original request URI (path and query parameters)
// in the main session. This is typically called at the start of the authentication flow.
//
// Parameters:
// - path: The original request URI string (e.g., "/protected/resource?id=123").
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+91 -33
View File
@@ -114,6 +114,14 @@ const (
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
// - EnablePKCE: false (PKCE is opt-in)
//
// CreateConfig initializes a new Config struct with default values for optional fields.
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
//
// Returns:
// - A pointer to a new Config struct with default settings applied.
func CreateConfig() *Config {
c := &Config{
Scopes: []string{"openid", "profile", "email"},
@@ -127,9 +135,14 @@ func CreateConfig() *Config {
return c
}
// Validate performs validation checks on the Config.
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
// Validate checks the configuration settings for validity.
// It ensures that required fields (ProviderURL, CallbackURL, ClientID, ClientSecret, SessionEncryptionKey)
// are present and that URLs are well-formed (HTTPS where required). It also validates
// the session key length, log level, rate limit, and refresh grace period.
//
// Returns:
// - nil if the configuration is valid.
// - An error describing the first validation failure encountered.
func (c *Config) Validate() error {
// Validate provider URL
if c.ProviderURL == "" {
@@ -211,13 +224,26 @@ func (c *Config) Validate() error {
return nil
}
// isValidSecureURL checks if the provided string is a valid HTTPS URL
// isValidSecureURL checks if a given string represents a valid, absolute HTTPS URL.
// It uses url.Parse and checks for a nil error, an "https" scheme, and a non-empty host.
//
// Parameters:
// - s: The URL string to validate.
//
// Returns:
// - true if the string is a valid HTTPS URL, false otherwise.
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
//
// Parameters:
// - level: The log level string to validate.
//
// Returns:
// - true if the log level is valid, false otherwise.
func isValidLogLevel(level string) bool {
return level == "debug" || level == "info" || level == "error"
}
@@ -234,14 +260,20 @@ type Logger struct {
logDebug *log.Logger
}
// NewLogger creates a new Logger with the specified log level.
// The log level determines which messages are output:
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
// NewLogger creates and configures a new Logger instance based on the provided log level.
// It initializes loggers for ERROR (stderr), INFO (stdout), and DEBUG (stdout) levels,
// enabling output based on the specified level:
// - "error": Only ERROR messages are output.
// - "info": INFO and ERROR messages are output.
// - "debug": DEBUG, INFO, and ERROR messages are output.
//
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
// If an invalid level is provided, it defaults to behavior similar to "error".
//
// Parameters:
// - logLevel: The desired logging level ("debug", "info", or "error").
//
// Returns:
// - A pointer to the configured Logger instance.
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
@@ -263,51 +295,77 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an informational message.
// These messages are intended for general operational information
// and are written to stdout.
// Info logs a message at the INFO level using Printf style formatting.
// Output is directed to stdout if the configured log level is "info" or "debug".
//
// Parameters:
// - format: The format string (as in fmt.Printf).
// - args: The arguments for the format string.
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
// Debug logs a message at the DEBUG level using Printf style formatting.
// Output is directed to stdout only if the configured log level is "debug".
//
// Parameters:
// - format: The format string (as in fmt.Printf).
// - args: The arguments for the format string.
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
// Error logs a message at the ERROR level using Printf style formatting.
// Output is always directed to stderr, regardless of the configured log level.
//
// Parameters:
// - format: The format string (as in fmt.Printf).
// - args: The arguments for the format string.
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an informational message using Printf formatting.
// These messages are intended for general operational information
// and are written to stdout.
// Infof logs a message at the INFO level using Printf style formatting.
// Equivalent to calling l.Info(format, args...).
// Output is directed to stdout if the configured log level is "info" or "debug".
//
// Parameters:
// - format: The format string (as in fmt.Printf).
// - args: The arguments for the format string.
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message using Printf formatting.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
// Debugf logs a message at the DEBUG level using Printf style formatting.
// Equivalent to calling l.Debug(format, args...).
// Output is directed to stdout only if the configured log level is "debug".
//
// Parameters:
// - format: The format string (as in fmt.Printf).
// - args: The arguments for the format string.
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message using Printf formatting.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
// Errorf logs a message at the ERROR level using Printf style formatting.
// Equivalent to calling l.Error(format, args...).
// Output is always directed to stderr, regardless of the configured log level.
//
// Parameters:
// - format: The format string (as in fmt.Printf).
// - args: The arguments for the format string.
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to both the HTTP response and the error log.
// It ensures consistent error handling across the middleware by logging the error
// and sending an appropriate HTTP response to the client.
// handleError logs an error message using the provided logger and sends an HTTP error
// response to the client with the specified message and status code.
//
// Parameters:
// - w: The http.ResponseWriter to send the error response to.
// - message: The error message string.
// - code: The HTTP status code for the response.
// - logger: The Logger instance to use for logging the error.
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)