From 1910cd6000fe52194ad816002dad9bda167138e0 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sat, 5 Apr 2025 11:31:45 +0100 Subject: [PATCH] Update documentation to the higher standards. --- autocleanup.go | 12 +- cache.go | 46 +++-- helpers.go | 167 ++++++++++++++----- jwk.go | 61 +++++++ jwt.go | 131 +++++++++++---- main.go | 417 ++++++++++++++++++++++++++++++++++++---------- main_test.go | 306 +++++++++++++++++++++++++++++++++- metadata_cache.go | 30 +++- session.go | 228 ++++++++++++++++++++----- settings.go | 124 ++++++++++---- 10 files changed, 1266 insertions(+), 256 deletions(-) diff --git a/autocleanup.go b/autocleanup.go index 9eb9b23..b18b752 100644 --- a/autocleanup.go +++ b/autocleanup.go @@ -2,8 +2,16 @@ package traefikoidc import "time" -// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval. -// It stops when a value is received on the stop channel. +// autoCleanupRoutine periodically calls the provided cleanup function. +// It starts a ticker with the given interval and executes the cleanup function +// on each tick. The routine stops gracefully when a signal is received on the +// stop channel. This is typically used for background cleanup tasks like +// expiring cache entries. +// +// Parameters: +// - interval: The time duration between cleanup calls. +// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop. +// - cleanup: The function to call periodically for cleanup tasks. func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) { ticker := time.NewTicker(interval) defer ticker.Stop() diff --git a/cache.go b/cache.go index 12f6ab2..dad624c 100644 --- a/cache.go +++ b/cache.go @@ -46,7 +46,9 @@ type Cache struct { // DefaultMaxSize is the default maximum number of items in the cache. const DefaultMaxSize = 500 -// NewCache creates a new empty cache instance that is ready for use. +// NewCache creates a new empty cache instance with default settings. +// It initializes the internal maps and list, sets the default maximum size, +// and starts the automatic cleanup goroutine. func NewCache() *Cache { c := &Cache{ items: make(map[string]CacheItem, DefaultMaxSize), @@ -60,8 +62,12 @@ func NewCache() *Cache { return c } -// Set adds or updates an item in the cache with the specified expiration duration. -// It moves the item to the most recently used position. +// Set adds or updates an item in the cache with the specified key, value, and expiration duration. +// If the key already exists, its value and expiration time are updated, and it's moved +// to the most recently used position in the LRU list. +// If the key does not exist and the cache is full, the least recently used item is evicted +// before adding the new item. +// The expiration duration is relative to the time Set is called. func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() @@ -95,8 +101,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { c.elems[key] = elem } -// Get retrieves an item from the cache if it exists and hasn't expired. -// Moving the accessed item to the most recently used position. +// Get retrieves an item from the cache by its key. +// If the item exists and has not expired, its value and true are returned. +// Accessing an item moves it to the most recently used position in the LRU list. +// If the item does not exist or has expired, nil and false are returned, and the +// expired item is removed from the cache. func (c *Cache) Get(key string) (interface{}, bool) { c.mutex.Lock() defer c.mutex.Unlock() @@ -120,7 +129,9 @@ func (c *Cache) Get(key string) (interface{}, bool) { return item.Value, true } -// Delete removes an item from the cache. +// Delete removes an item from the cache by its key. +// If the key exists, the corresponding item is removed from the cache storage +// and the LRU list. func (c *Cache) Delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() @@ -128,8 +139,10 @@ func (c *Cache) Delete(key string) { c.removeItem(key) } -// Cleanup removes all expired items from the cache. This should be called periodically -// to prevent memory bloat from expired entries. +// Cleanup iterates through the cache and removes all items that have expired. +// An item is considered expired if the current time is after its ExpiresAt timestamp. +// This method is called automatically by the auto-cleanup goroutine, but can also +// be called manually. func (c *Cache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() @@ -143,7 +156,11 @@ func (c *Cache) Cleanup() { } } -// evictOldest removes the least recently used item from the cache. +// evictOldest removes the least recently used (oldest) item from the cache. +// It first attempts to find and remove an expired item from the front of the LRU list. +// If no expired items are found at the front, it removes the absolute oldest item (front of the list). +// This method is called internally by Set when the cache reaches its maximum size. +// Note: This function assumes the write lock is already held. func (c *Cache) evictOldest() { now := time.Now() elem := c.order.Front() @@ -167,7 +184,9 @@ func (c *Cache) evictOldest() { } } -// removeItem removes an item from both the cache and the LRU tracking structures. +// removeItem removes an item specified by the key from the cache's internal storage (items map) +// and its corresponding entry from the LRU list (order list and elems map). +// Note: This function assumes the write lock is already held. func (c *Cache) removeItem(key string) { delete(c.items, key) if elem, ok := c.elems[key]; ok { @@ -176,12 +195,15 @@ func (c *Cache) removeItem(key string) { } } -// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items. +// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method +// at the interval specified by c.autoCleanupInterval. +// It uses the autoCleanupRoutine helper function. func (c *Cache) startAutoCleanup() { autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup) } -// Close terminates the auto cleanup goroutine. +// Close stops the automatic cleanup goroutine associated with this cache instance. +// It should be called when the cache is no longer needed to prevent resource leaks. func (c *Cache) Close() { close(c.stopCleanup) } diff --git a/helpers.go b/helpers.go index 9f9f64e..ea1f41d 100644 --- a/helpers.go +++ b/helpers.go @@ -15,10 +15,14 @@ import ( "time" ) -// generateNonce creates a cryptographically secure random nonce -// for use in the OIDC authentication flow. The nonce is used to -// prevent replay attacks by ensuring the token received matches -// the authentication request. +// generateNonce creates a cryptographically secure random string suitable for use as an OIDC nonce. +// The nonce is used during the authentication flow to mitigate replay attacks by associating +// the ID token with the specific authentication request. +// It generates 32 random bytes and encodes them using base64 URL encoding. +// +// Returns: +// - A base64 URL encoded random string (nonce). +// - An error if the random byte generation fails. func generateNonce() (string, error) { nonceBytes := make([]byte, 32) _, err := rand.Read(nonceBytes) @@ -28,9 +32,13 @@ func generateNonce() (string, error) { return base64.URLEncoding.EncodeToString(nonceBytes), nil } -// generateCodeVerifier creates a cryptographically secure random string -// for use as a PKCE code verifier. The code verifier must be between 43 and 128 -// characters long, per the PKCE spec (RFC 7636). +// generateCodeVerifier creates a cryptographically secure random string suitable for use as a PKCE code verifier. +// According to RFC 7636, the verifier should be a high-entropy string between 43 and 128 characters long. +// This function generates 32 random bytes, resulting in a 43-character base64 URL encoded string. +// +// Returns: +// - A base64 URL encoded random string (code verifier). +// - An error if the random byte generation fails. func generateCodeVerifier() (string, error) { // Using 32 bytes (256 bits) will produce a 43 character base64url string verifierBytes := make([]byte, 32) @@ -41,8 +49,15 @@ func generateCodeVerifier() (string, error) { return base64.RawURLEncoding.EncodeToString(verifierBytes), nil } -// deriveCodeChallenge creates a code challenge from a code verifier -// using the SHA-256 method as specified in the PKCE standard (RFC 7636). +// deriveCodeChallenge computes the PKCE code challenge from a given code verifier. +// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding) +// as defined in RFC 7636. +// +// Parameters: +// - codeVerifier: The high-entropy string generated by generateCodeVerifier. +// +// Returns: +// - The base64 URL encoded SHA-256 hash of the code verifier (code challenge). func deriveCodeChallenge(codeVerifier string) string { // Calculate SHA-256 hash of the code verifier hasher := sha256.New() @@ -73,14 +88,22 @@ type TokenResponse struct { TokenType string `json:"token_type"` } -// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider. -// It supports both authorization code and refresh token grant types. +// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint. +// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens) +// and the "refresh_token" grant type (using a refresh token to obtain new tokens). +// It includes necessary parameters like client credentials and handles PKCE verification if applicable. +// The function follows redirects and handles potential errors during the exchange. +// // Parameters: -// - ctx: Context for the HTTP request -// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token") -// - codeOrToken: Either the authorization code or refresh token -// - redirectURL: The callback URL for authorization code grant -// - codeVerifier: Optional PKCE code verifier for authorization code grant +// - ctx: The context for the outgoing HTTP request. +// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token"). +// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant). +// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant). +// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used). +// +// Returns: +// - A TokenResponse containing the obtained tokens (ID, access, refresh). +// - An error if the token exchange fails (e.g., network error, provider error, invalid grant). func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, @@ -140,8 +163,16 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code return &tokenResponse, nil } -// getNewTokenWithRefreshToken obtains new tokens using a refresh token. -// This is used to refresh access tokens before they expire. +// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh) +// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the +// "refresh_token" grant type. +// +// Parameters: +// - refreshToken: The refresh token previously obtained during authentication or a prior refresh. +// +// Returns: +// - A TokenResponse containing the newly obtained tokens. +// - An error if the refresh operation fails. func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { ctx := context.Background() tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "") @@ -153,8 +184,17 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe return tokenResponse, nil } -// extractClaims parses a JWT token and extracts its claims. -// It handles base64url decoding and JSON parsing of the token payload. +// extractClaims decodes the payload (claims set) part of a JWT string. +// It splits the JWT into its three parts, base64 URL decodes the second part (payload), +// and unmarshals the resulting JSON into a map. +// Note: This function does *not* validate the token's signature or claims. +// +// Parameters: +// - tokenString: The raw JWT string. +// +// Returns: +// - A map representing the JSON claims extracted from the token payload. +// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails. func extractClaims(tokenString string) (map[string]interface{}, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -182,21 +222,36 @@ type TokenCache struct { cache *Cache } -// NewTokenCache creates a new TokenCache instance. +// NewTokenCache creates and initializes a new TokenCache. +// It internally creates a new generic Cache instance for storage. func NewTokenCache() *TokenCache { return &TokenCache{ cache: NewCache(), } } -// Set stores a token's claims in the cache with an expiration time. +// Set stores the claims associated with a specific token string in the cache. +// It prefixes the token string to avoid potential collisions with other cache types +// and sets the provided expiration duration. +// +// Parameters: +// - token: The raw token string (used as the key). +// - claims: The map of claims associated with the token. +// - expiration: The duration for which the cache entry should be valid. func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) { token = "t-" + token tc.cache.Set(token, claims, expiration) } -// Get retrieves a token's claims from the cache. -// Returns the claims and a boolean indicating if the token was found. +// Get retrieves the cached claims for a given token string. +// It prefixes the token string before querying the underlying cache. +// +// Parameters: +// - token: The raw token string to look up. +// +// Returns: +// - The cached claims map if found and valid. +// - A boolean indicating whether the token was found in the cache (true if found, false otherwise). func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { token = "t-" + token value, found := tc.cache.Get(token) @@ -207,20 +262,34 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { return claims, ok } -// Delete removes a token from the cache. +// Delete removes the cached entry for a specific token string. +// It prefixes the token string before calling the underlying cache's Delete method. +// +// Parameters: +// - token: The raw token string to remove from the cache. func (tc *TokenCache) Delete(token string) { token = "t-" + token tc.cache.Delete(token) } -// Cleanup removes expired tokens from the cache. +// Cleanup triggers the cleanup process for the underlying generic cache, +// removing expired token entries. func (tc *TokenCache) Cleanup() { tc.cache.Cleanup() } -// exchangeCodeForToken exchanges an authorization code for tokens. -// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration. -// The code verifier is only included in the token request if PKCE is enabled. +// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically +// for the "authorization_code" grant type. It handles the conditional inclusion of the +// PKCE code verifier based on the middleware's configuration (t.enablePKCE). +// +// Parameters: +// - code: The authorization code received from the OIDC provider. +// - redirectURL: The redirect URI used in the initial authorization request. +// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled). +// +// Returns: +// - A TokenResponse containing the obtained tokens. +// - An error if the code exchange fails. func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { ctx := context.Background() @@ -237,8 +306,15 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, code return tokenResponse, nil } -// createStringMap creates a map from a slice of strings. -// Used for efficient lookups in allowed domains and roles. +// createStringMap converts a slice of strings into a map[string]struct{} (a set). +// This is useful for creating efficient lookups (O(1) average time complexity) +// for checking the presence of items like allowed domains, roles, or groups. +// +// Parameters: +// - keys: A slice of strings to be added to the set. +// +// Returns: +// - A map where the keys are the strings from the input slice and the values are empty structs. func createStringMap(keys []string) map[string]struct{} { result := make(map[string]struct{}) for _, key := range keys { @@ -247,9 +323,17 @@ func createStringMap(keys []string) map[string]struct{} { return result } -// handleLogout manages the OIDC logout process. -// It clears the session and redirects either to the OIDC provider's -// end session endpoint (if available) or to the configured post-logout URL. +// handleLogout processes requests to the configured logout path. +// It performs the following steps: +// 1. Retrieves the current user session. +// 2. Gets the access token (ID token hint) from the session. +// 3. Clears all authentication-related data from the session cookies. +// 4. Determines the final post-logout redirect URI. +// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available, +// it builds the OIDC logout URL and redirects the user agent to the provider for logout. +// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI. +// +// It handles potential errors during session retrieval or clearing. func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { session, err := t.sessionManager.GetSession(req) if err != nil { @@ -291,11 +375,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound) } -// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters. +// BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's +// end_session_endpoint, including the required id_token_hint and optional +// post_logout_redirect_uri parameters as query arguments. +// // Parameters: -// - endSessionURL: The OIDC provider's end session endpoint -// - idToken: The ID token to be invalidated -// - postLogoutRedirectURI: Where to redirect after logout completes +// - endSessionURL: The URL of the OIDC provider's end session endpoint. +// - idToken: The ID token previously issued to the user (used as id_token_hint). +// - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout. +// +// Returns: +// - The fully constructed logout URL string. +// - An error if the provided endSessionURL is invalid. func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) { u, err := url.Parse(endSessionURL) if err != nil { diff --git a/jwk.go b/jwk.go index de5aa22..a3e44e8 100644 --- a/jwk.go +++ b/jwk.go @@ -45,6 +45,21 @@ type JWKCacheInterface interface { Cleanup() } +// GetJWKS retrieves the JSON Web Key Set (JWKS) from the cache or fetches it from the provider. +// It first checks if a valid, non-expired JWKS is present in the cache. If so, it returns the cached version. +// Otherwise, it attempts to fetch the JWKS from the specified jwksURL using the provided httpClient. +// If the fetch is successful, the JWKS is stored in the cache with an expiration time based on CacheLifetime +// (defaulting to 1 hour if not set) and returned. +// This method uses double-checked locking to minimize contention when the cache needs refreshing. +// +// Parameters: +// - ctx: Context for the HTTP request if fetching is required. +// - jwksURL: The URL of the OIDC provider's JWKS endpoint. +// - httpClient: The HTTP client to use for fetching the JWKS. +// +// Returns: +// - A pointer to the JWKSet containing the keys. +// - An error if fetching fails or the response cannot be decoded. func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { c.mutex.RLock() if c.jwks != nil && time.Now().Before(c.expiresAt) { @@ -74,6 +89,8 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http return jwks, nil } +// Cleanup removes the cached JWKS if it has expired. +// This is intended to be called periodically to ensure stale JWKS data is cleared. func (c *JWKCache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() @@ -84,6 +101,17 @@ func (c *JWKCache) Cleanup() { } } +// fetchJWKS retrieves the JSON Web Key Set (JWKS) from the specified URL. +// It uses the provided context and HTTP client to make the request. +// +// Parameters: +// - ctx: Context for the HTTP request. +// - jwksURL: The URL of the OIDC provider's JWKS endpoint. +// - httpClient: The HTTP client to use for the request. +// +// Returns: +// - A pointer to the fetched JWKSet. +// - An error if the request fails, the status code is not OK, or the response body cannot be decoded. func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { // Create a request with context to enforce timeout req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) @@ -109,6 +137,16 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J return &jwks, nil } +// jwkToPEM converts a JWK (JSON Web Key) object into PEM (Privacy-Enhanced Mail) format. +// It selects the appropriate conversion function based on the JWK's key type ("kty"). +// Currently supports "RSA" and "EC" key types. +// +// Parameters: +// - jwk: A pointer to the JWK object to convert. +// +// Returns: +// - A byte slice containing the public key in PEM format. +// - An error if the key type is unsupported or conversion fails. func jwkToPEM(jwk *JWK) ([]byte, error) { converter, ok := jwkConverters[jwk.Kty] if !ok { @@ -124,6 +162,17 @@ var jwkConverters = map[string]jwkToPEMConverter{ "EC": ecJWKToPEM, } +// rsaJWKToPEM converts an RSA JWK into PEM format. +// It decodes the modulus (n) and exponent (e) from base64 URL encoding, +// constructs an rsa.PublicKey, marshals it into PKIX format, and then +// encodes it as a PEM block. +// +// Parameters: +// - jwk: A pointer to the RSA JWK object (must have "kty": "RSA"). +// +// Returns: +// - A byte slice containing the RSA public key in PEM format. +// - An error if decoding parameters fails or key marshaling fails. func rsaJWKToPEM(jwk *JWK) ([]byte, error) { nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) if err != nil { @@ -155,6 +204,18 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) { return pubKeyPEM, nil } +// ecJWKToPEM converts an EC (Elliptic Curve) JWK into PEM format. +// It decodes the X and Y coordinates from base64 URL encoding, determines the +// elliptic curve based on the "crv" parameter (P-256, P-384, P-521), +// constructs an ecdsa.PublicKey, marshals it into PKIX format, and then +// encodes it as a PEM block. +// +// Parameters: +// - jwk: A pointer to the EC JWK object (must have "kty": "EC"). +// +// Returns: +// - A byte slice containing the EC public key in PEM format. +// - An error if decoding parameters fails, the curve is unsupported, or key marshaling fails. func ecJWKToPEM(jwk *JWK) ([]byte, error) { xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X) if err != nil { diff --git a/jwt.go b/jwt.go index febab3d..9a8ab50 100644 --- a/jwt.go +++ b/jwt.go @@ -20,6 +20,10 @@ var ( replayCache = make(map[string]time.Time) ) +// cleanupReplayCache iterates through the replay cache and removes entries +// whose expiration time is before the current time. This function should be +// called periodically to prevent the cache from growing indefinitely. +// It acquires a mutex to ensure thread safety during cleanup. func cleanupReplayCache() { now := time.Now() for token, expiry := range replayCache { @@ -48,7 +52,18 @@ type JWT struct { Token string } -// parseJWT parses a JWT token string into a JWT struct. +// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature. +// It splits the token string by '.', decodes each part using base64 URL decoding, +// and unmarshals the header and claims JSON into maps. The raw signature bytes are stored. +// It performs basic format validation (expecting 3 parts). +// Note: This function does *not* validate the signature or the claims. +// +// Parameters: +// - tokenString: The raw JWT string. +// +// Returns: +// - A pointer to a JWT struct containing the decoded parts. +// - An error if the token format is invalid or decoding/unmarshaling fails. func parseJWT(tokenString string) (*JWT, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -84,8 +99,24 @@ func parseJWT(tokenString string) (*JWT, error) { return jwt, nil } -// Verify validates the standard JWT claims as defined in RFC 7519. -// Verify validates the standard JWT claims as defined in RFC 7519. +// Verify performs standard claim validation on the JWT according to RFC 7519. +// It checks the following: +// - Algorithm ('alg') is supported. +// - Issuer ('iss') matches the expected issuerURL. +// - Audience ('aud') contains the expected clientID. +// - Expiration time ('exp') is in the future (within tolerance). +// - Issued at time ('iat') is in the past (within tolerance). +// - Not before time ('nbf'), if present, is in the past (within tolerance). +// - Subject ('sub') claim exists and is not empty. +// - JWT ID ('jti'), if present, is checked against a replay cache to prevent token reuse. +// +// Parameters: +// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com"). +// - clientID: The expected audience value (the client ID of this application). +// +// Returns: +// - nil if all standard claims are valid. +// - An error describing the first validation failure encountered. func (j *JWT) Verify(issuerURL, clientID string) error { // Validate algorithm to prevent algorithm switching attacks alg, ok := j.Header["alg"].(string) @@ -175,6 +206,16 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } +// verifyAudience checks if the expected audience is present in the token's 'aud' claim. +// The 'aud' claim can be a single string or an array of strings. +// +// Parameters: +// - tokenAudience: The 'aud' claim value extracted from the token (can be string or []interface{}). +// - expectedAudience: The audience value expected for this application (client ID). +// +// Returns: +// - nil if the expected audience is found. +// - An error if the claim type is invalid or the expected audience is not present. func verifyAudience(tokenAudience interface{}, expectedAudience string) error { switch aud := tokenAudience.(type) { case string: @@ -198,6 +239,15 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error { return nil } +// verifyIssuer checks if the token's 'iss' claim matches the expected issuer URL. +// +// Parameters: +// - tokenIssuer: The 'iss' claim value from the token. +// - expectedIssuer: The expected issuer URL configured for the OIDC provider. +// +// Returns: +// - nil if the issuers match. +// - An error if the issuers do not match. func verifyIssuer(tokenIssuer, expectedIssuer string) error { if tokenIssuer != expectedIssuer { return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer) @@ -205,57 +255,76 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error { return nil } -// verifyTimeConstraint is a generic function to verify time-based claims +// verifyTimeConstraint checks time-based claims ('exp', 'iat', 'nbf') against the current time, +// allowing for configurable clock skew. It uses different tolerances for past and future checks. +// +// Parameters: +// - unixTime: The timestamp value from the claim (as a float64 Unix time). +// - claimName: The name of the claim being verified ("exp", "iat", "nbf"). +// - future: A boolean indicating the direction of the check (true for 'exp', false for 'iat'/'nbf'). +// +// Returns: +// - nil if the time constraint is met within the allowed tolerance. +// - An error describing the failure (e.g., "token has expired", "token used before issued"). func verifyTimeConstraint(unixTime float64, claimName string, future bool) error { claimTime := time.Unix(int64(unixTime), 0) - now := time.Now().Truncate(time.Second) + now := time.Now() // Use current time without truncation - var skewedNow time.Time - if future { - // For expiration (future=true), add skew to now (making now later) - // Use the larger tolerance for future checks (exp) - skewedNow = now.Add(ClockSkewToleranceFuture) - } else { - // For iat/nbf (future=false), subtract skew from now (making now earlier) - // Use the smaller, specific tolerance for past checks (iat, nbf) - skewedNow = now.Add(-ClockSkewTolerancePast) // Subtract the past tolerance - } - - if claimTime.Equal(now) { - return nil - } - - // For expiration: if skewedNow (later) is after expiration, token expired - // For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid - if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) { - var reason string - if future { - reason = "has expired" - } else { + var err error + if future { // 'exp' check + // Token is expired if Now is after (ClaimTime + FutureTolerance) + allowedExpiry := claimTime.Add(ClockSkewToleranceFuture) + if now.After(allowedExpiry) { + err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC()) + } + } else { // 'iat' or 'nbf' check + // Token is invalid if Now is before (ClaimTime - PastTolerance) + allowedStart := claimTime.Add(-ClockSkewTolerancePast) + if now.Before(allowedStart) { + reason := "not yet valid" if claimName == "iat" { reason = "used before issued" - } else { - reason = "not yet valid" } + err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC()) } - return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC()) } - return nil + return err } +// verifyExpiration checks the 'exp' (Expiration Time) claim. +// It calls verifyTimeConstraint with future=true. func verifyExpiration(expiration float64) error { return verifyTimeConstraint(expiration, "exp", true) } +// verifyIssuedAt checks the 'iat' (Issued At) claim. +// It calls verifyTimeConstraint with future=false. func verifyIssuedAt(issuedAt float64) error { return verifyTimeConstraint(issuedAt, "iat", false) } +// verifyNotBefore checks the 'nbf' (Not Before) claim. +// It calls verifyTimeConstraint with future=false. func verifyNotBefore(notBefore float64) error { return verifyTimeConstraint(notBefore, "nbf", false) } +// verifySignature validates the JWT's signature using the provided public key. +// It parses the public key from PEM format, selects the appropriate hashing algorithm +// based on the 'alg' parameter (SHA256/384/512), hashes the token's signing input +// (header + "." + payload), and then verifies the signature against the hash using +// the corresponding RSA (PKCS1v15 or PSS) or ECDSA verification method. +// +// Parameters: +// - tokenString: The raw, complete JWT string. +// - publicKeyPEM: The public key corresponding to the private key used for signing, in PEM format. +// - alg: The algorithm specified in the JWT header (e.g., "RS256", "ES384"). +// +// Returns: +// - nil if the signature is valid. +// - An error if the token format is invalid, decoding fails, key parsing fails, +// the algorithm is unsupported, or the signature verification fails. func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { parts := strings.Split(tokenString, ".") if len(parts) != 3 { diff --git a/main.go b/main.go index 92c9536..9e77b12 100644 --- a/main.go +++ b/main.go @@ -17,7 +17,13 @@ import ( "golang.org/x/time/rate" ) -// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC +// createDefaultHTTPClient creates a new http.Client with settings optimized for OIDC communication. +// It configures the transport with specific timeouts (dial, keepalive, TLS handshake, idle connection), +// connection limits (max idle, max per host), enables HTTP/2, and sets a default request timeout. +// It also configures redirect handling to follow redirects up to a limit. +// +// Returns: +// - A pointer to the configured http.Client. func createDefaultHTTPClient() *http.Client { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -128,16 +134,20 @@ var defaultExcludedURLs = map[string]struct{}{ "/favicon": {}, } -// VerifyToken implements the TokenVerifier interface to verify an OIDC token. -// It performs a complete verification process including: -// 1. Checking the token cache to avoid redundant verifications -// 2. Performing rate limiting and blacklist checks -// 3. Parsing the JWT structure -// 4. Verifying the JWT signature against the JWKS from the provider -// 5. Validating standard JWT claims (iss, aud, exp, etc.) -// 6. Caching the verified token for future requests +// VerifyToken implements the TokenVerifier interface. It performs a comprehensive validation of an ID token: +// 1. Checks the token cache; returns nil immediately if a valid cached entry exists. +// 2. Performs pre-verification checks (rate limiting, blacklist). +// 3. Parses the raw token string into a JWT struct. +// 4. Verifies the JWT signature and standard claims (iss, aud, exp, iat, nbf, sub) using VerifyJWTSignatureAndClaims. +// 5. If verification succeeds, caches the token claims until the token's expiration time. +// 6. If verification succeeds and the token has a JTI claim, adds the JTI to the blacklist cache to prevent replay attacks. // -// Returns nil if the token is valid, or an error describing the validation failure. +// Parameters: +// - token: The raw ID token string to verify. +// +// Returns: +// - nil if the token is valid according to all checks. +// - An error describing the reason for validation failure (e.g., rate limit, blacklisted, parsing error, signature error, claim error). func (t *TraefikOidc) VerifyToken(token string) error { // Check cache first if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 { @@ -192,7 +202,16 @@ func (t *TraefikOidc) VerifyToken(token string) error { return nil } -// performPreVerificationChecks performs rate limiting and blacklist checks +// performPreVerificationChecks executes preliminary checks before attempting full token validation. +// It enforces rate limiting using the configured limiter and checks if the raw token string +// or its JTI (if extractable) exists in the blacklist cache. +// +// Parameters: +// - token: The raw token string being verified. +// +// Returns: +// - nil if all pre-verification checks pass. +// - An error if the rate limit is exceeded or the token/JTI is blacklisted. func (t *TraefikOidc) performPreVerificationChecks(token string) error { // Enforce rate limiting if !t.limiter.Allow() { @@ -218,7 +237,13 @@ func (t *TraefikOidc) performPreVerificationChecks(token string) error { return nil } -// cacheVerifiedToken caches a verified token until its expiration time +// cacheVerifiedToken adds the claims of a successfully verified token to the token cache. +// It calculates the remaining duration until the token's 'exp' claim and uses that +// duration for the cache entry's lifetime. +// +// Parameters: +// - token: The raw token string (used as the cache key). +// - claims: The map of claims extracted from the verified token. func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) { expirationTime := time.Unix(int64(claims["exp"].(float64)), 0) now := time.Now() @@ -226,7 +251,18 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa t.tokenCache.Set(token, claims, duration) } -// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims +// VerifyJWTSignatureAndClaims implements the JWTVerifier interface. It verifies the signature +// of a parsed JWT against the provider's public keys obtained from the JWKS endpoint, +// and then validates the standard JWT claims (iss, aud, exp, iat, nbf, sub, jti replay). +// +// Parameters: +// - jwt: A pointer to the parsed JWT struct containing header and claims. +// - token: The original raw token string (used for signature verification). +// +// Returns: +// - nil if both the signature and all standard claims are valid. +// - An error describing the validation failure (e.g., failed to get JWKS, missing kid/alg, +// no matching key, signature verification failed, standard claim validation failed). func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error { t.logger.Debugf("Verifying JWT signature and claims") @@ -277,24 +313,28 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error return nil } -// New creates a new instance of the OIDC middleware. -// This is the main entry point for the middleware and is called by Traefik when loading the plugin. -// It initializes all components needed for OIDC authentication: -// - Session management for storing user state -// - Token caching and blacklisting -// - JWK caching for signature verification -// - Rate limiting to prevent abuse -// - Metadata discovery for OIDC provider endpoints +// New is the constructor for the TraefikOidc middleware plugin. +// It is called by Traefik during plugin initialization. It performs the following steps: +// 1. Creates a default configuration if none is provided. +// 2. Validates the session encryption key length. +// 3. Initializes the logger based on the configured log level. +// 4. Sets up the HTTP client (using defaults if none provided in config). +// 5. Creates the main TraefikOidc struct, populating fields from the config +// (paths, client details, PKCE/HTTPS flags, scopes, rate limiter, caches, allowed lists). +// 6. Initializes the SessionManager. +// 7. Sets up internal function pointers/interfaces (extractClaimsFunc, initiateAuthenticationFunc, tokenVerifier, jwtVerifier, tokenExchanger). +// 8. Adds default excluded URLs. +// 9. Starts background goroutines for token cache cleanup and OIDC provider metadata initialization/refresh. // // Parameters: -// - ctx: Context for initialization operations -// - next: The next handler in the middleware chain -// - config: Configuration options for the middleware -// - name: Identifier for this middleware instance +// - ctx: The context provided by Traefik for initialization. +// - next: The next http.Handler in the Traefik middleware chain. +// - config: The plugin configuration provided by the user in Traefik static/dynamic configuration. +// - name: The name assigned to this middleware instance by Traefik. // // Returns: -// - An http.Handler that implements the middleware -// - An error if initialization fails +// - An http.Handler (the TraefikOidc instance itself, which implements ServeHTTP). +// - An error if essential configuration is missing or invalid (e.g., short encryption key). func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { if config == nil { config = CreateConfig() @@ -386,7 +426,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return t, nil } -// initializeMetadata discovers and initializes the provider metadata +// initializeMetadata asynchronously fetches and caches the OIDC provider metadata. +// It uses the MetadataCache to retrieve potentially cached data or fetch fresh data +// via discoverProviderMetadata. On successful retrieval, it updates the middleware's +// endpoint URLs (auth, token, jwks, etc.), starts the periodic metadata refresh goroutine, +// and signals completion by closing the initComplete channel. If fetching fails initially, +// it logs an error and the middleware might remain uninitialized until a successful refresh. +// +// Parameters: +// - providerURL: The base URL of the OIDC provider. func (t *TraefikOidc) initializeMetadata(providerURL string) { t.logger.Debug("Starting provider metadata discovery") @@ -412,7 +460,12 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { t.logger.Error("Received nil metadata") } -// updateMetadataEndpoints updates the middleware with metadata endpoints +// updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.) +// within the TraefikOidc instance based on the discovered provider metadata. +// This is called after successfully fetching or refreshing the metadata. +// +// Parameters: +// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints. func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { t.jwksURL = metadata.JWKSURL t.authURL = metadata.AuthURL @@ -422,7 +475,13 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { t.endSessionURL = metadata.EndSessionURL } -// startMetadataRefresh periodically refreshes the OIDC metadata +// startMetadataRefresh starts a background goroutine that periodically attempts to refresh +// the OIDC provider metadata by calling GetMetadata on the metadataCache. +// It runs on a fixed ticker (currently 1 hour). Successful refreshes update the +// middleware's endpoint URLs via updateMetadataEndpoints. Fetch errors are logged. +// +// Parameters: +// - providerURL: The base URL of the OIDC provider, used for subsequent refresh attempts. func (t *TraefikOidc) startMetadataRefresh(providerURL string) { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() @@ -442,7 +501,19 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) { } } -// discoverProviderMetadata fetches the OIDC provider metadata +// discoverProviderMetadata attempts to fetch the OIDC provider's configuration from its +// well-known discovery endpoint (".well-known/openid-configuration"). +// It implements an exponential backoff retry mechanism in case of transient network errors +// or provider unavailability during startup. +// +// Parameters: +// - providerURL: The base URL of the OIDC provider. +// - httpClient: The HTTP client to use for the request. +// - l: The logger instance for recording retries and errors. +// +// Returns: +// - A pointer to the fetched ProviderMetadata struct. +// - An error if fetching fails after all retries or if a timeout is exceeded. func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) { wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" @@ -481,7 +552,16 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr) } -// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint +// fetchMetadata performs a single attempt to fetch and decode the OIDC provider metadata +// from the specified well-known configuration URL. +// +// Parameters: +// - wellKnownURL: The full URL to the ".well-known/openid-configuration" endpoint. +// - httpClient: The HTTP client to use for the GET request. +// +// Returns: +// - A pointer to the decoded ProviderMetadata struct. +// - An error if the GET request fails, the status code is not 200 OK, or JSON decoding fails. func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) { resp, err := httpClient.Get(wellKnownURL) if err != nil { @@ -504,18 +584,22 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad return &metadata, nil } -// ServeHTTP is the main handler for the middleware that processes all HTTP requests. -// It implements the http.Handler interface and performs the following operations: -// 1. Waits for OIDC provider metadata initialization to complete -// 2. Checks if the requested URL is in the excluded list (bypassing authentication) -// 3. Retrieves or creates a user session -// 4. Handles special paths like callback and logout URLs -// 5. Verifies authentication status and token validity -// 6. Refreshes tokens that are about to expire -// 7. Validates user email domains, roles, and groups against configured restrictions -// 8. Sets appropriate headers for downstream services -// 9. Applies security headers to responses -// 10. Forwards the authenticated request to the next handler +// ServeHTTP is the main entry point for incoming requests to the middleware. +// It orchestrates the OIDC authentication flow: +// 1. Waits for initial OIDC metadata discovery to complete (with timeout). +// 2. Checks if the request path is excluded from authentication. +// 3. Checks if the request is for Server-Sent Events and bypasses if so. +// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing. +// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout). +// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry). +// 7. If the token is expired, handles it (initiates re-auth). +// 8. If the user is not authenticated, initiates authentication. +// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure +// by returning 401 for API clients or initiating re-auth for browsers. +// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups). +// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*) +// and adds security headers (X-Frame-Options, etc.) to the response. +// 12. Forwards the request to the next handler in the chain. func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { select { case <-t.initComplete: @@ -698,8 +782,17 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.next.ServeHTTP(rw, req) } -// handleExpiredToken manages token expiration by clearing the session -// and initiating a new authentication flow. +// handleExpiredToken is called when a user's session contains an expired token or +// when a token refresh attempt fails for a browser client. +// It clears the authentication-related data (tokens, email, authenticated flag) from the session, +// saves the cleared session, and then initiates a new authentication flow by calling +// defaultInitiateAuthentication, redirecting the user to the OIDC provider. +// +// Parameters: +// - rw: The HTTP response writer. +// - req: The HTTP request. +// - session: The user's session data containing the expired token information. +// - redirectURL: The callback URL to be used in the new authentication flow. func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { // Clear authentication data but preserve CSRF state session.SetAuthenticated(false) @@ -717,9 +810,27 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque t.defaultInitiateAuthentication(rw, req, session, redirectURL) } -// handleCallback processes the authentication callback from the OIDC provider. -// It validates the callback parameters, exchanges the authorization code for -// tokens, verifies the tokens, and establishes the user's session. +// handleCallback handles the request received at the OIDC callback URL (redirect_uri). +// It performs the following steps: +// 1. Retrieves the user session associated with the callback request. +// 2. Checks for error parameters returned by the OIDC provider. +// 3. Validates the 'state' parameter against the CSRF token stored in the session. +// 4. Extracts the authorization 'code' from the query parameters. +// 5. Retrieves the PKCE 'code_verifier' from the session (if PKCE is enabled). +// 6. Exchanges the authorization code for tokens using the TokenExchanger interface. +// 7. Verifies the received ID token's signature and standard claims using VerifyToken. +// 8. Extracts claims from the verified ID token. +// 9. Verifies the 'nonce' claim against the nonce stored in the session. +// 10. Validates the user's email domain against the allowed list. +// 11. If all checks pass, updates the session with authentication details (status, email, tokens). +// 12. Saves the updated session. +// 13. Redirects the user back to their original requested path (stored in session) or the root path. +// If any step fails, it sends an appropriate error response using sendErrorResponse. +// +// Parameters: +// - rw: The HTTP response writer. +// - req: The incoming HTTP request to the callback URL. +// - redirectURL: The fully qualified callback URL (used in the token exchange request). func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { session, err := t.sessionManager.GetSession(req) if err != nil { @@ -846,7 +957,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, http.Redirect(rw, req, redirectPath, http.StatusFound) } -// determineExcludedURL checks if the current request URL is in the excluded list +// determineExcludedURL checks if the provided request path matches any of the configured excluded URL prefixes. +// +// Parameters: +// - currentRequest: The path part of the incoming request URL. +// +// Returns: +// - true if the path starts with any of the prefixes in the t.excludedURLs map. +// - false otherwise. func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { for excludedURL := range t.excludedURLs { if strings.HasPrefix(currentRequest, excludedURL) { @@ -858,7 +976,15 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { return false } -// determineScheme determines the scheme (http or https) of the request +// determineScheme determines the request scheme (http or https). +// It prioritizes the X-Forwarded-Proto header if present, otherwise checks +// the TLS property of the request. Defaults to "http". +// +// Parameters: +// - req: The incoming HTTP request. +// +// Returns: +// - "https" or "http". func (t *TraefikOidc) determineScheme(req *http.Request) string { if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" { return scheme @@ -869,7 +995,14 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string { return "http" } -// determineHost determines the host of the request +// determineHost determines the request host. +// It prioritizes the X-Forwarded-Host header if present, otherwise uses the req.Host value. +// +// Parameters: +// - req: The incoming HTTP request. +// +// Returns: +// - The determined host string (e.g., "example.com:8080"). func (t *TraefikOidc) determineHost(req *http.Request) string { if host := req.Header.Get("X-Forwarded-Host"); host != "" { return host @@ -877,17 +1010,18 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { return req.Host } -// isUserAuthenticated checks if the user is authenticated by validating their session and token. -// It performs a comprehensive check of the authentication state including: -// 1. Verifying the session's authenticated flag -// 2. Checking for the presence of an access token -// 3. Validating the token's signature and claims -// 4. Checking the token's expiration time +// isUserAuthenticated checks the authentication status based on the provided session data. +// It verifies the session's authenticated flag, the presence and validity of the access token (ID token), +// including signature and standard claims (using VerifyJWTSignatureAndClaims). It also checks if the +// token is within the configured refreshGracePeriod before its actual expiration. // -// Returns three boolean values: -// - authenticated: Whether the user is currently authenticated -// - needsRefresh: Whether the token is valid but will expire soon (within grace period) -// - expired: Whether the token has expired or is otherwise invalid +// Parameters: +// - session: The SessionData object for the current user. +// +// Returns: +// - authenticated (bool): True if the session is marked authenticated and the token is present and valid (signature/claims ok, not expired beyond grace). +// - needsRefresh (bool): True if the token is valid but nearing expiration (within refreshGracePeriod) OR if VerifyJWTSignatureAndClaims failed specifically due to expiration (meaning refresh might be possible). +// - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims). func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) { if !session.GetAuthenticated() { t.logger.Debug("User is not authenticated according to session") @@ -945,19 +1079,18 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo return true, false, false } -// defaultInitiateAuthentication initiates the OIDC authentication process. -// This function prepares and starts a new authentication flow by: -// 1. Generating security tokens (CSRF token and nonce) to prevent attacks -// 2. Clearing any existing session data to avoid state conflicts -// 3. Storing the original request path to redirect back after authentication -// 4. Building the authorization URL with all required OIDC parameters -// 5. Redirecting the user to the OIDC provider's authorization endpoint +// defaultInitiateAuthentication handles the process of starting an OIDC authentication flow. +// It generates necessary security values (CSRF token, nonce, PKCE verifier/challenge if enabled), +// clears any potentially stale data from the current session, stores the new security values +// and the original request URI in the session, saves the session (setting cookies), +// builds the OIDC authorization endpoint URL with required parameters, and finally +// redirects the user's browser to that URL. // // Parameters: -// - rw: The HTTP response writer for sending the redirect -// - req: The original HTTP request that triggered authentication -// - session: The user's session data for storing authentication state -// - redirectURL: The callback URL where the OIDC provider will redirect after authentication +// - rw: The HTTP response writer used to send the redirect response. +// - req: The original incoming HTTP request that requires authentication. +// - session: The user's SessionData object (potentially new or cleared). +// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance. func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { // Generate CSRF token and nonce csrfToken := uuid.NewString() @@ -1007,15 +1140,33 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req http.Redirect(rw, req, authURL, http.StatusFound) } -// verifyToken verifies the token using the token verifier interface. -// This function delegates to the configured token verifier implementation, -// which by default is the TraefikOidc instance itself (implementing the VerifyToken method). -// This design allows for easy mocking in tests and potential future extension. +// verifyToken is a wrapper method that calls the VerifyToken method of the configured +// TokenVerifier interface (which defaults to the TraefikOidc instance itself). +// This primarily exists to facilitate testing and potential future extensions where +// token verification logic might be delegated differently. +// +// Parameters: +// - token: The raw token string to verify. +// +// Returns: +// - The result of calling t.tokenVerifier.VerifyToken(token). func (t *TraefikOidc) verifyToken(token string) error { return t.tokenVerifier.VerifyToken(token) } -// buildAuthURL constructs the authentication URL with optional PKCE support +// buildAuthURL constructs the OIDC authorization endpoint URL with all necessary query parameters +// for initiating the authorization code flow. It includes client_id, response_type, redirect_uri, +// state, nonce, and optionally PKCE parameters (code_challenge, code_challenge_method) if enabled +// and a challenge is provided. It also includes configured scopes. +// +// Parameters: +// - redirectURL: The callback URL (redirect_uri). +// - state: The CSRF token. +// - nonce: The OIDC nonce. +// - codeChallenge: The PKCE code challenge (can be empty if PKCE is disabled or not used). +// +// Returns: +// - The fully constructed authorization URL string. func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string { params := url.Values{} params.Set("client_id", t.clientID) @@ -1037,7 +1188,16 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri return t.buildURLWithParams(t.authURL, params) } -// buildURLWithParams ensures a URL is absolute and appends query parameters +// buildURLWithParams takes a base URL and query parameters and constructs a full URL string. +// If the baseURL is relative (doesn't start with http/https), it prepends the scheme and host +// from the configured issuerURL. It then appends the encoded query parameters. +// +// Parameters: +// - baseURL: The base URL (can be absolute or relative to the issuer). +// - params: A url.Values map containing the query parameters to append. +// +// Returns: +// - The fully constructed URL string with appended query parameters. func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string { // Ensure URL is absolute if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { @@ -1054,7 +1214,8 @@ func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) stri return baseURL + "?" + params.Encode() } -// startTokenCleanup starts the token cleanup goroutine +// startTokenCleanup starts background goroutines for periodically cleaning up +// the token cache, token blacklist cache, and JWK cache using the autoCleanupRoutine helper. func (t *TraefikOidc) startTokenCleanup() { ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute go func() { @@ -1069,7 +1230,14 @@ func (t *TraefikOidc) startTokenCleanup() { }() } -// RevokeToken adds the token to the blacklist +// RevokeToken handles local revocation of a token. +// It removes the token from the validation cache (tokenCache) and adds the raw +// token string to the blacklist cache (tokenBlacklist) with a default expiration (24h). +// This prevents the token from being validated successfully even if it hasn't expired yet. +// Note: This does *not* revoke the token with the OIDC provider. +// +// Parameters: +// - token: The raw token string to revoke locally. func (t *TraefikOidc) RevokeToken(token string) { // Remove from cache t.tokenCache.Delete(token) @@ -1080,7 +1248,17 @@ func (t *TraefikOidc) RevokeToken(token string) { t.tokenBlacklist.Set(token, true, time.Until(expiry)) } -// RevokeTokenWithProvider revokes the token with the provider +// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider +// using the revocation endpoint specified in the provider metadata or configuration. +// It sends a POST request with the token, token_type_hint, client_id, and client_secret. +// +// Parameters: +// - token: The token (e.g., refresh token or access token) to revoke. +// - tokenType: The type hint for the token being revoked (e.g., "refresh_token"). +// +// Returns: +// - nil if the revocation request is successful (provider returns 200 OK). +// - An error if the request fails or the provider returns a non-OK status. func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { t.logger.Debugf("Revoking token with provider") @@ -1117,7 +1295,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { return nil } -// refreshToken refreshes the user's token, protected by a mutex within the session. +// refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens. +// It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session. +// It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method, +// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens, +// and saves the session. +// +// Parameters: +// - rw: The HTTP response writer (needed for saving the updated session). +// - req: The HTTP request (needed for saving the updated session). +// - session: The user's SessionData object containing the refresh token. +// +// Returns: +// - true if the token refresh was successful and the session was updated. +// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed. func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool { // Lock the mutex specific to this session instance before attempting refresh session.refreshMutex.Lock() @@ -1159,7 +1350,16 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return true } -// isAllowedDomain checks if the user's email domain is allowed +// isAllowedDomain checks if the domain part of the provided email address is present +// in the configured list of allowed domains (t.allowedUserDomains). +// If the allowed domains list is empty, all domains are considered allowed. +// +// Parameters: +// - email: The email address to check. +// +// Returns: +// - true if the domain is allowed or if no domain restrictions are configured. +// - false if the email format is invalid or the domain is not in the allowed list. func (t *TraefikOidc) isAllowedDomain(email string) bool { if len(t.allowedUserDomains) == 0 { return true // If no domains are specified, all are allowed @@ -1175,7 +1375,18 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool { return ok } -// extractGroupsAndRoles extracts groups and roles from the id_token +// extractGroupsAndRoles attempts to extract 'groups' and 'roles' claims from a decoded ID token. +// It expects these claims, if present, to be arrays of strings. +// It uses the configured extractClaimsFunc (which defaults to the package-level extractClaims) +// to get the claims map from the token string. +// +// Parameters: +// - idToken: The raw ID token string. +// +// Returns: +// - A slice of strings containing the groups found in the 'groups' claim. +// - A slice of strings containing the roles found in the 'roles' claim. +// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present but not arrays of strings. func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) { claims, err := t.extractClaimsFunc(idToken) if err != nil { @@ -1216,7 +1427,17 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, return groups, roles, nil } -// buildFullURL constructs a full URL from scheme, host and path +// buildFullURL constructs an absolute URL string from its components. +// If the provided path already starts with "http://" or "https://", it's returned directly. +// Otherwise, it combines the scheme, host, and path, ensuring the path starts with a '/'. +// +// Parameters: +// - scheme: The URL scheme ("http" or "https"). +// - host: The host part of the URL (e.g., "example.com:8080"). +// - path: The path part of the URL (e.g., "/resource"). +// +// Returns: +// - The combined absolute URL string (e.g., "https://example.com:8080/resource"). func buildFullURL(scheme, host, path string) string { // If the path is already a full URL, return it as-is if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { @@ -1233,21 +1454,35 @@ func buildFullURL(scheme, host, path string) string { // --- TokenExchanger Interface Implementation --- -// ExchangeCodeForToken implements the TokenExchanger interface. -// It calls the existing exchangeTokens helper function. +// ExchangeCodeForToken provides the implementation for the TokenExchanger interface method. +// It directly calls the internal exchangeTokens method, passing through the arguments. +// This allows the TraefikOidc struct to act as its own default TokenExchanger, while +// still allowing mocking for tests. func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Note: The original exchangeTokens helper is defined in helpers.go and is already a method on *TraefikOidc return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier) } -// GetNewTokenWithRefreshToken implements the TokenExchanger interface. -// It calls the existing getNewTokenWithRefreshToken helper function. +// GetNewTokenWithRefreshToken provides the implementation for the TokenExchanger interface method. +// It directly calls the internal getNewTokenWithRefreshToken helper method. +// This allows the TraefikOidc struct to act as its own default TokenExchanger, while +// still allowing mocking for tests. func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { // Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc return t.getNewTokenWithRefreshToken(refreshToken) } -// sendErrorResponse sends an error response, adapting to the client's Accept header. +// sendErrorResponse sends an error response to the client, adapting the format based +// on the request's Accept header. If the client prefers "application/json", it sends +// a JSON object with "error", "error_description", and "status_code" fields. +// Otherwise, it sends a basic HTML error page containing the message and a link +// back to the application root or the original incoming path (if available from the session). +// +// Parameters: +// - rw: The HTTP response writer. +// - req: The HTTP request (used to check Accept header and potentially get session). +// - message: The error message to display/include in the response. +// - code: The HTTP status code to set for the response. func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) { acceptHeader := req.Header.Get("Accept") diff --git a/main_test.go b/main_test.go index fff7910..b86c325 100644 --- a/main_test.go +++ b/main_test.go @@ -376,6 +376,7 @@ func TestServeHTTP(t *testing.T) { setupSession func(*SessionData) mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks + requestHeaders map[string]string // Added for setting headers like Accept }{ { name: "Excluded URL", @@ -394,7 +395,13 @@ func TestServeHTTP(t *testing.T) { setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") - session.SetAccessToken(ts.token) // Use the valid token generated in Setup + // Generate a fresh valid token for this test case to avoid replay issues + freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com", + "jti": generateRandomString(16), // Unique JTI + }) + session.SetAccessToken(freshToken) session.SetRefreshToken("valid-refresh-token") }, expectedStatus: http.StatusOK, @@ -450,14 +457,161 @@ func TestServeHTTP(t *testing.T) { }, { name: "Logout URL", - requestPath: "/logout", // Assuming logout path is configured or defaulted correctly + requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") - session.SetAccessToken(ts.token) + // Generate a fresh valid token for this test case + freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com", + "jti": generateRandomString(16), // Unique JTI + }) + session.SetAccessToken(freshToken) }, expectedStatus: http.StatusFound, // Expect redirect after logout expectedBody: "", + // No specific session assertion needed for logout redirect itself + }, + { + name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(createExpiredToken()) + session.SetRefreshToken("valid-refresh-token") + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + // Simulate failed refresh + return nil, fmt.Errorf("mock error: refresh token invalid or provider down") + } + }, + requestHeaders: map[string]string{ + "Accept": "application/json", + }, + expectedStatus: http.StatusUnauthorized, // Expect 401 for API client + expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`, + }, + { + name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(createExpiredToken()) + session.SetRefreshToken("valid-refresh-token") + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + // Simulate failed refresh + return nil, fmt.Errorf("mock error: refresh token invalid or provider down") + } + }, + requestHeaders: map[string]string{ + "Accept": "text/html", // Browser client + }, + expectedStatus: http.StatusFound, // Expect redirect for browser client + }, + { + name: "Authenticated request with token nearing expiry (needs refresh)", + requestPath: "/protected", + setupSession: func(session *SessionData) { + // Create token expiring soon (e.g., 30s, within default 60s grace period) + exp := time.Now().Add(30 * time.Second).Unix() + iat := time.Now().Add(-1 * time.Minute).Unix() + nbf := time.Now().Add(-1 * time.Minute).Unix() + nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, + "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), + }) + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(nearExpiryToken) + session.SetRefreshToken("valid-refresh-token-for-near-expiry") + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + if refreshToken != "valid-refresh-token-for-near-expiry" { + return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken) + } + // Simulate successful refresh + newToken := createNewValidToken() + return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil + } + }, + expectedStatus: http.StatusOK, // Expect success after proactive refresh + expectedBody: "OK", + }, + { + name: "Authenticated request with token valid (outside grace period)", + requestPath: "/protected", + setupSession: func(session *SessionData) { + // Create token expiring later (e.g., 10 mins, outside default 60s grace period) + exp := time.Now().Add(10 * time.Minute).Unix() + iat := time.Now().Add(-1 * time.Minute).Unix() + nbf := time.Now().Add(-1 * time.Minute).Unix() + validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, + "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), + }) + session.SetAuthenticated(true) + session.SetEmail("user@example.com") + session.SetAccessToken(validToken) + session.SetRefreshToken("should-not-be-used-refresh-token") + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + // This should NOT be called + return func(refreshToken string) (*TokenResponse, error) { + t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period") + return nil, fmt.Errorf("refresh should not have been attempted") + } + }, + expectedStatus: http.StatusOK, // Expect success, no refresh needed + expectedBody: "OK", + }, + { + name: "Disallowed Domain (Accept: JSON)", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@disallowed.com") // Use disallowed domain + // Generate a fresh valid token for this test case + freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email + "jti": generateRandomString(16), // Unique JTI + }) + session.SetAccessToken(freshToken) + session.SetRefreshToken("valid-refresh-token") + }, + requestHeaders: map[string]string{ + "Accept": "application/json", + }, + expectedStatus: http.StatusForbidden, + expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`, + }, + { + name: "Disallowed Domain (Accept: HTML)", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) + session.SetEmail("user@disallowed.com") // Use disallowed domain + // Generate a fresh valid token for this test case + freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email + "jti": generateRandomString(16), // Unique JTI + }) + session.SetAccessToken(freshToken) + session.SetRefreshToken("valid-refresh-token") + }, + requestHeaders: map[string]string{ + "Accept": "text/html", + }, + expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response + expectedBody: "", // Body check is harder for HTML, focus on status and content-type }, } @@ -468,6 +622,12 @@ func TestServeHTTP(t *testing.T) { req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that req.Header.Set("X-Forwarded-Host", "testhost.com") req.Host = "testhost.com" // Also set Host header + // Set request headers from test case + if tc.requestHeaders != nil { + for key, value := range tc.requestHeaders { + req.Header.Set(key, value) + } + } rr := httptest.NewRecorder() @@ -527,10 +687,19 @@ func TestServeHTTP(t *testing.T) { } // Check response body if expected + // Check response body if expected (handle JSON vs HTML) if tc.expectedBody != "" { - if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { - t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body) + // For JSON, compare directly + if strings.Contains(rr.Header().Get("Content-Type"), "application/json") { + if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { + t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body) + } + } else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler + if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { + t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body) + } } + // Add more sophisticated HTML body checks if needed } // Perform post-request session assertions if defined @@ -2319,3 +2488,130 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath) } } + +// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances. +func TestVerifyTimeConstraint(t *testing.T) { + // Define tolerances used in jwt.go (ensure they match) + toleranceFuture := 2 * time.Minute + tolerancePast := 10 * time.Second + + now := time.Now() + + tests := []struct { + name string + claimTime time.Time + claimName string + futureCheck bool // true for exp, false for iat/nbf + expectError bool + }{ + // Expiration (future=true, tolerance=2min) + { + name: "EXP: Valid (expires in 1 min)", + claimTime: now.Add(1 * time.Minute), + claimName: "exp", + futureCheck: true, + expectError: false, + }, + { + name: "EXP: Expired (expired 3 min ago)", + claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance + claimName: "exp", + futureCheck: true, + expectError: true, + }, + { + name: "EXP: Valid (expired 1 min ago, within 2min tolerance)", + claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance + claimName: "exp", + futureCheck: true, + expectError: false, // Should be allowed due to future tolerance + }, + + // Issued At (future=false, tolerance=10s) + { + name: "IAT: Valid (issued 1 min ago)", + claimTime: now.Add(-1 * time.Minute), + claimName: "iat", + futureCheck: false, + expectError: false, + }, + { + name: "IAT: Invalid (issued 15 sec in future)", + claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance + claimName: "iat", + futureCheck: false, + expectError: true, // "token used before issued" + }, + { + name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)", + claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance + claimName: "iat", + futureCheck: false, + expectError: false, // Should be allowed due to past tolerance + }, + + // Not Before (future=false, tolerance=10s) + { + name: "NBF: Valid (active 1 min ago)", + claimTime: now.Add(-1 * time.Minute), + claimName: "nbf", + futureCheck: false, + expectError: false, + }, + { + name: "NBF: Invalid (active in 15 sec)", + claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance + claimName: "nbf", + futureCheck: false, + expectError: true, // "token not yet valid" + }, + { + name: "NBF: Valid (active in 5 sec, within 10s tolerance)", + claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance + claimName: "nbf", + futureCheck: false, + expectError: false, // Should be allowed due to past tolerance + }, + } + + // Temporarily adjust global tolerances for test consistency, then restore + originalFutureTolerance := ClockSkewToleranceFuture + originalPastTolerance := ClockSkewTolerancePast + ClockSkewToleranceFuture = toleranceFuture + ClockSkewTolerancePast = tolerancePast + defer func() { + ClockSkewToleranceFuture = originalFutureTolerance + ClockSkewTolerancePast = originalPastTolerance + }() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Convert claim time to float64 unix timestamp + unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9 + + var err error + // Call the specific verification function which uses verifyTimeConstraint + if tc.claimName == "exp" { + err = verifyExpiration(unixTime) + } else if tc.claimName == "iat" { + err = verifyIssuedAt(unixTime) + } else if tc.claimName == "nbf" { + err = verifyNotBefore(unixTime) + } else { + t.Fatalf("Unknown claim name in test setup: %s", tc.claimName) + } + + if tc.expectError { + if err == nil { + t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now) + } else { + t.Logf("Got expected error: %v", err) // Log the error for confirmation + } + } else { + if err != nil { + t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err) + } + } + }) + } +} // Add missing closing brace for TestVerifyTimeConstraint diff --git a/metadata_cache.go b/metadata_cache.go index 0e96d59..60390d8 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -15,6 +15,8 @@ type MetadataCache struct { stopCleanup chan struct{} } +// NewMetadataCache creates a new MetadataCache instance. +// It initializes the cache structure and starts the background cleanup goroutine. func NewMetadataCache() *MetadataCache { c := &MetadataCache{ autoCleanupInterval: 5 * time.Minute, @@ -24,7 +26,8 @@ func NewMetadataCache() *MetadataCache { return c } -// Cleanup removes expired metadata from the cache. +// Cleanup removes the cached provider metadata if it has expired. +// This is called periodically by the auto-cleanup goroutine. func (c *MetadataCache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() @@ -35,11 +38,31 @@ func (c *MetadataCache) Cleanup() { } } +// isCacheValid checks if the cached metadata is present and has not expired. +// Note: This function assumes the read lock is held or it's called from a context +// where the lock is already held (like within GetMetadata after locking). func (c *MetadataCache) isCacheValid() bool { return c.metadata != nil && time.Now().Before(c.expiresAt) } -// GetMetadata retrieves the metadata from cache or fetches it if expired +// GetMetadata retrieves the OIDC provider metadata. +// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately. +// If the cache is empty or expired, it attempts to fetch the metadata from the provider's +// well-known endpoint using discoverProviderMetadata. +// If fetching is successful, the new metadata is cached for 1 hour. +// If fetching fails but valid metadata exists in the cache (even if expired), the cache expiry +// is extended by 5 minutes, and the cached data is returned to prevent thundering herd issues. +// If fetching fails and there's no cached data, an error is returned. +// It employs double-checked locking for thread safety and performance. +// +// Parameters: +// - providerURL: The base URL of the OIDC provider. +// - httpClient: The HTTP client to use for fetching metadata. +// - logger: The logger instance for recording errors or warnings. +// +// Returns: +// - A pointer to the ProviderMetadata struct. +// - An error if metadata cannot be retrieved from cache or fetched from the provider. func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) { c.mutex.RLock() if c.isCacheValid() { @@ -76,10 +99,13 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, return metadata, nil } +// startAutoCleanup starts the background goroutine that periodically calls Cleanup +// to remove expired metadata from the cache. func (c *MetadataCache) startAutoCleanup() { autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup) } +// Close stops the automatic cleanup goroutine associated with this metadata cache. func (c *MetadataCache) Close() { close(c.stopCleanup) } diff --git a/session.go b/session.go index 14ff0e8..a7781bd 100644 --- a/session.go +++ b/session.go @@ -16,8 +16,15 @@ import ( "github.com/gorilla/sessions" ) -// generateSecureRandomString creates a cryptographically secure random string of specified length. -// It returns the generated string or an error if random generation fails. +// generateSecureRandomString creates a cryptographically secure, hex-encoded random string. +// It reads the specified number of bytes from crypto/rand and encodes them as a hexadecimal string. +// +// Parameters: +// - length: The number of random bytes to generate (the resulting hex string will be twice this length). +// +// Returns: +// - A hex-encoded random string. +// - An error if reading random bytes fails. func generateSecureRandomString(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { @@ -56,7 +63,14 @@ const ( minEncryptionKeyLength = 32 ) -// compressToken compresses a token using gzip and base64 encodes it. +// compressToken compresses the input string using gzip and then encodes the result using standard base64 encoding. +// If any error occurs during compression, it returns the original uncompressed token as a fallback. +// +// Parameters: +// - token: The string to compress. +// +// Returns: +// - The base64 encoded, gzipped string, or the original string if compression fails. func compressToken(token string) string { var b bytes.Buffer gz := gzip.NewWriter(&b) @@ -69,7 +83,15 @@ func compressToken(token string) string { return base64.StdEncoding.EncodeToString(b.Bytes()) } -// decompressToken decompresses a base64 encoded gzipped token. +// decompressToken decodes a standard base64 encoded string and then decompresses the result using gzip. +// If base64 decoding or gzip decompression fails, it returns the original input string as a fallback, +// assuming it might not have been compressed. +// +// Parameters: +// - compressed: The base64 encoded, gzipped string. +// +// Returns: +// - The decompressed original string, or the input string if decompression fails. func decompressToken(compressed string) string { data, err := base64.StdEncoding.DecodeString(compressed) if err != nil { @@ -140,15 +162,15 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (* return sm, nil } -// getSessionOptions returns secure session options configured for the current request. -// Parameters: -// - isSecure: Whether the current request is using HTTPS. +// getSessionOptions returns a sessions.Options struct configured with security best practices. +// It sets HttpOnly to true, Secure based on the request scheme or forceHTTPS setting, +// SameSite to LaxMode, MaxAge to the absoluteSessionTimeout, and Path to "/". // -// The options ensure cookies are: -// - HTTP-only (not accessible via JavaScript) -// - Secure when using HTTPS or when forceHTTPS is enabled -// - Using SameSite=Lax for CSRF protection -// - Set with appropriate timeout and path settings +// Parameters: +// - isSecure: A boolean indicating if the current request context is secure (HTTPS). +// +// Returns: +// - A pointer to a configured sessions.Options struct. func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { return &sessions.Options{ HttpOnly: true, @@ -210,11 +232,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { return sessionData, nil } -// getTokenChunkSessions retrieves all session chunks for a given token type. +// getTokenChunkSessions retrieves all cookie chunks associated with a large token (access or refresh). +// It iteratively attempts to load cookies named "{baseName}_0", "{baseName}_1", etc., until +// a cookie is not found or returns an error. The loaded sessions are stored in the provided chunks map. +// // Parameters: -// - r: The HTTP request -// - baseName: The base name for the token's session cookies -// - chunks: Map to store the chunks in +// - r: The incoming HTTP request containing the cookies. +// - baseName: The base name of the cookie (e.g., accessTokenCookie). +// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks) to populate with the found session chunks. func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", baseName, i) @@ -258,10 +283,16 @@ type SessionData struct { refreshMutex sync.Mutex } -// Save persists all session data to cookies in the HTTP response. -// It saves the main session, token sessions, and any token chunks, -// applying appropriate security options to each cookie. All cookies -// are saved with consistent security settings based on the request scheme. +// Save persists all parts of the session (main, access token, refresh token, and any chunks) +// back to the client as cookies in the HTTP response. It applies secure cookie options +// obtained via getSessionOptions based on the request's security context. +// +// Parameters: +// - r: The original HTTP request (used to determine security context for cookie options). +// - w: The HTTP response writer to which the Set-Cookie headers will be added. +// +// Returns: +// - An error if saving any of the session components fails. func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS @@ -305,7 +336,19 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { return nil } -// Clear removes all session data by expiring all cookies and clearing their values. +// Clear removes all session data associated with this SessionData instance. +// It clears the values map of the main, access, and refresh sessions, sets their MaxAge to -1 +// to expire the cookies immediately, and clears any associated token chunk cookies. +// If a ResponseWriter is provided, it attempts to save the expired sessions to send the +// expiring Set-Cookie headers. Finally, it clears internal fields and returns the SessionData +// object to the pool. +// +// Parameters: +// - r: The HTTP request (required by the underlying session store). +// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. +// +// Returns: +// - An error if saving the expired sessions fails (only if w is not nil). func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { // Clear and expire all sessions. sd.mainSession.Options.MaxAge = -1 @@ -340,7 +383,12 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { return err } -// clearTokenChunks removes all session chunks for a given token type. +// clearTokenChunks iterates through a map of session chunks, clears their values, +// and sets their MaxAge to -1 to expire them. This is used internally by Clear. +// +// Parameters: +// - r: The HTTP request (required by the underlying session store, though not directly used here). +// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire. func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) { for _, session := range chunks { session.Options.MaxAge = -1 @@ -350,7 +398,12 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session } } -// GetAuthenticated returns whether the current session is authenticated. +// GetAuthenticated checks if the session is marked as authenticated and has not exceeded +// the absolute session timeout. +// +// Returns: +// - true if the "authenticated" flag is set to true and the session creation time is within the allowed timeout. +// - false otherwise. func (sd *SessionData) GetAuthenticated() bool { auth, _ := sd.mainSession.Values["authenticated"].(bool) if !auth { @@ -365,8 +418,15 @@ func (sd *SessionData) GetAuthenticated() bool { return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout } -// SetAuthenticated updates the session's authentication status and rotates session ID. -// Returns an error if generating a new session ID fails. +// SetAuthenticated sets the authentication status of the session. +// If setting to true, it generates a new secure session ID for the main session +// to prevent session fixation attacks and records the current time as the creation time. +// +// Parameters: +// - value: The boolean authentication status (true for authenticated, false otherwise). +// +// Returns: +// - An error if generating a new session ID fails when setting value to true. func (sd *SessionData) SetAuthenticated(value bool) error { if value { id, err := generateSecureRandomString(32) @@ -380,7 +440,12 @@ func (sd *SessionData) SetAuthenticated(value bool) error { return nil } -// GetAccessToken retrieves the complete access token from the session. +// GetAccessToken retrieves the access token stored in the session. +// It handles reassembling the token from multiple cookie chunks if necessary +// and decompresses it if it was stored compressed. +// +// Returns: +// - The complete, decompressed access token string, or an empty string if not found. func (sd *SessionData) GetAccessToken() string { token, _ := sd.accessSession.Values["token"].(string) if token != "" { @@ -414,7 +479,14 @@ func (sd *SessionData) GetAccessToken() string { return token } -// SetAccessToken stores the access token in the session. +// SetAccessToken stores the provided access token in the session. +// It first expires any existing access token chunk cookies. +// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize), +// it's stored directly in the primary access token session. Otherwise, the compressed token +// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_a_0, _oidc_raczylo_a_1, etc.). +// +// Parameters: +// - token: The access token string to store. func (sd *SessionData) SetAccessToken(token string) { // Expire any existing chunk cookies first. if sd.request != nil { @@ -444,7 +516,12 @@ func (sd *SessionData) SetAccessToken(token string) { } } -// GetRefreshToken retrieves the complete refresh token from the session. +// GetRefreshToken retrieves the refresh token stored in the session. +// It handles reassembling the token from multiple cookie chunks if necessary +// and decompresses it if it was stored compressed. +// +// Returns: +// - The complete, decompressed refresh token string, or an empty string if not found. func (sd *SessionData) GetRefreshToken() string { token, _ := sd.refreshSession.Values["token"].(string) if token != "" { @@ -478,7 +555,14 @@ func (sd *SessionData) GetRefreshToken() string { return token } -// SetRefreshToken stores the refresh token in the session. +// SetRefreshToken stores the provided refresh token in the session. +// It first expires any existing refresh token chunk cookies. +// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize), +// it's stored directly in the primary refresh token session. Otherwise, the compressed token +// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_r_0, _oidc_raczylo_r_1, etc.). +// +// Parameters: +// - token: The refresh token string to store. func (sd *SessionData) SetRefreshToken(token string) { // Expire any existing chunk cookies first. if sd.request != nil { @@ -508,7 +592,13 @@ func (sd *SessionData) SetRefreshToken(token string) { } } -// expireAccessTokenChunks expires any existing access token chunk cookies. +// expireAccessTokenChunks finds all existing access token chunk cookies (_oidc_raczylo_a_N) +// associated with the current request, clears their values, and sets their MaxAge to -1. +// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send +// the expiring Set-Cookie headers. This is used internally when setting a new access token. +// +// Parameters: +// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i) @@ -526,7 +616,13 @@ func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) { } } -// expireRefreshTokenChunks expires any existing refresh token chunk cookies. +// expireRefreshTokenChunks finds all existing refresh token chunk cookies (_oidc_raczylo_r_N) +// associated with the current request, clears their values, and sets their MaxAge to -1. +// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send +// the expiring Set-Cookie headers. This is used internally when setting a new refresh token. +// +// Parameters: +// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent. func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) { for i := 0; ; i++ { sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i) @@ -544,7 +640,15 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) { } } -// splitIntoChunks splits a string into chunks of specified size. +// splitIntoChunks divides a string `s` into a slice of strings, where each element +// has a maximum length of `chunkSize`. +// +// Parameters: +// - s: The string to split. +// - chunkSize: The maximum size of each chunk. +// +// Returns: +// - A slice of strings representing the chunks. func splitIntoChunks(s string, chunkSize int) []string { var chunks []string for len(s) > 0 { @@ -559,57 +663,97 @@ func splitIntoChunks(s string, chunkSize int) []string { return chunks } -// GetCSRF retrieves the CSRF token from the session. +// GetCSRF retrieves the Cross-Site Request Forgery (CSRF) token stored in the main session. +// +// Returns: +// - The CSRF token string, or an empty string if not set. func (sd *SessionData) GetCSRF() string { csrf, _ := sd.mainSession.Values["csrf"].(string) return csrf } -// SetCSRF stores a new CSRF token in the session. +// SetCSRF stores the provided CSRF token string in the main session. +// This token is typically generated at the start of the authentication flow. +// +// Parameters: +// - token: The CSRF token to store. func (sd *SessionData) SetCSRF(token string) { sd.mainSession.Values["csrf"] = token } -// GetNonce retrieves the nonce value from the session. +// GetNonce retrieves the OIDC nonce value stored in the main session. +// The nonce is used to associate an ID token with the specific authentication request. +// +// Returns: +// - The nonce string, or an empty string if not set. func (sd *SessionData) GetNonce() string { nonce, _ := sd.mainSession.Values["nonce"].(string) return nonce } -// SetNonce stores a new nonce value in the session. +// SetNonce stores the provided OIDC nonce string in the main session. +// This nonce is typically generated at the start of the authentication flow. +// +// Parameters: +// - nonce: The nonce string to store. func (sd *SessionData) SetNonce(nonce string) { sd.mainSession.Values["nonce"] = nonce } -// GetCodeVerifier retrieves the PKCE code verifier from the session. +// GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier +// stored in the main session. This is only relevant if PKCE is enabled. +// +// Returns: +// - The code verifier string, or an empty string if not set or PKCE is disabled. func (sd *SessionData) GetCodeVerifier() string { codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string) return codeVerifier } -// SetCodeVerifier stores the PKCE code verifier in the session. +// SetCodeVerifier stores the provided PKCE code verifier string in the main session. +// This is typically called at the start of the authentication flow if PKCE is enabled. +// +// Parameters: +// - codeVerifier: The PKCE code verifier string to store. func (sd *SessionData) SetCodeVerifier(codeVerifier string) { sd.mainSession.Values["code_verifier"] = codeVerifier } -// GetEmail retrieves the authenticated user's email address from the session. +// GetEmail retrieves the authenticated user's email address stored in the main session. +// This is typically extracted from the ID token claims after successful authentication. +// +// Returns: +// - The user's email address string, or an empty string if not set. func (sd *SessionData) GetEmail() string { email, _ := sd.mainSession.Values["email"].(string) return email } -// SetEmail stores the user's email address in the session. +// SetEmail stores the provided user email address string in the main session. +// This is typically called after successful authentication and claim extraction. +// +// Parameters: +// - email: The user's email address to store. func (sd *SessionData) SetEmail(email string) { sd.mainSession.Values["email"] = email } -// GetIncomingPath retrieves the original request path that triggered the authentication flow. +// GetIncomingPath retrieves the original request URI (including query parameters) +// that the user was trying to access before being redirected for authentication. +// This is stored in the main session to allow redirection back after successful login. +// +// Returns: +// - The original request URI string, or an empty string if not set. func (sd *SessionData) GetIncomingPath() string { path, _ := sd.mainSession.Values["incoming_path"].(string) return path } -// SetIncomingPath stores the original request path that triggered the authentication flow. +// SetIncomingPath stores the original request URI (path and query parameters) +// in the main session. This is typically called at the start of the authentication flow. +// +// Parameters: +// - path: The original request URI string (e.g., "/protected/resource?id=123"). func (sd *SessionData) SetIncomingPath(path string) { sd.mainSession.Values["incoming_path"] = path } diff --git a/settings.go b/settings.go index 2965767..c737cc5 100644 --- a/settings.go +++ b/settings.go @@ -114,6 +114,14 @@ const ( // - PostLogoutRedirectURI: "/" // - ForceHTTPS: true (for security) // - EnablePKCE: false (PKCE is opt-in) +// +// CreateConfig initializes a new Config struct with default values for optional fields. +// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the +// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret, +// CallbackURL, and SessionEncryptionKey must be set explicitly after creation. +// +// Returns: +// - A pointer to a new Config struct with default settings applied. func CreateConfig() *Config { c := &Config{ Scopes: []string{"openid", "profile", "email"}, @@ -127,9 +135,14 @@ func CreateConfig() *Config { return c } -// Validate performs validation checks on the Config. -// It ensures all required fields are set and have valid values. -// Returns an error if any validation check fails. +// Validate checks the configuration settings for validity. +// It ensures that required fields (ProviderURL, CallbackURL, ClientID, ClientSecret, SessionEncryptionKey) +// are present and that URLs are well-formed (HTTPS where required). It also validates +// the session key length, log level, rate limit, and refresh grace period. +// +// Returns: +// - nil if the configuration is valid. +// - An error describing the first validation failure encountered. func (c *Config) Validate() error { // Validate provider URL if c.ProviderURL == "" { @@ -211,13 +224,26 @@ func (c *Config) Validate() error { return nil } -// isValidSecureURL checks if the provided string is a valid HTTPS URL +// isValidSecureURL checks if a given string represents a valid, absolute HTTPS URL. +// It uses url.Parse and checks for a nil error, an "https" scheme, and a non-empty host. +// +// Parameters: +// - s: The URL string to validate. +// +// Returns: +// - true if the string is a valid HTTPS URL, false otherwise. func isValidSecureURL(s string) bool { u, err := url.Parse(s) return err == nil && u.Scheme == "https" && u.Host != "" } -// isValidLogLevel checks if the provided log level is valid +// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error"). +// +// Parameters: +// - level: The log level string to validate. +// +// Returns: +// - true if the log level is valid, false otherwise. func isValidLogLevel(level string) bool { return level == "debug" || level == "info" || level == "error" } @@ -234,14 +260,20 @@ type Logger struct { logDebug *log.Logger } -// NewLogger creates a new Logger with the specified log level. -// The log level determines which messages are output: -// - "debug": Outputs all messages (debug, info, error) -// - "info": Outputs info and error messages -// - "error": Outputs only error messages +// NewLogger creates and configures a new Logger instance based on the provided log level. +// It initializes loggers for ERROR (stderr), INFO (stdout), and DEBUG (stdout) levels, +// enabling output based on the specified level: +// - "error": Only ERROR messages are output. +// - "info": INFO and ERROR messages are output. +// - "debug": DEBUG, INFO, and ERROR messages are output. // -// Error messages are always written to stderr, while info and debug -// messages are written to stdout when enabled. +// If an invalid level is provided, it defaults to behavior similar to "error". +// +// Parameters: +// - logLevel: The desired logging level ("debug", "info", or "error"). +// +// Returns: +// - A pointer to the configured Logger instance. func NewLogger(logLevel string) *Logger { logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime) logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime) @@ -263,51 +295,77 @@ func NewLogger(logLevel string) *Logger { } } -// Info logs an informational message. -// These messages are intended for general operational information -// and are written to stdout. +// Info logs a message at the INFO level using Printf style formatting. +// Output is directed to stdout if the configured log level is "info" or "debug". +// +// Parameters: +// - format: The format string (as in fmt.Printf). +// - args: The arguments for the format string. func (l *Logger) Info(format string, args ...interface{}) { l.logInfo.Printf(format, args...) } -// Debug logs a debug message. -// These messages are only output when debug level logging is enabled -// and are intended for detailed troubleshooting information. +// Debug logs a message at the DEBUG level using Printf style formatting. +// Output is directed to stdout only if the configured log level is "debug". +// +// Parameters: +// - format: The format string (as in fmt.Printf). +// - args: The arguments for the format string. func (l *Logger) Debug(format string, args ...interface{}) { l.logDebug.Printf(format, args...) } -// Error logs an error message. -// These messages indicate problems that need attention and are -// always written to stderr regardless of the log level. +// Error logs a message at the ERROR level using Printf style formatting. +// Output is always directed to stderr, regardless of the configured log level. +// +// Parameters: +// - format: The format string (as in fmt.Printf). +// - args: The arguments for the format string. func (l *Logger) Error(format string, args ...interface{}) { l.logError.Printf(format, args...) } -// Infof logs an informational message using Printf formatting. -// These messages are intended for general operational information -// and are written to stdout. +// Infof logs a message at the INFO level using Printf style formatting. +// Equivalent to calling l.Info(format, args...). +// Output is directed to stdout if the configured log level is "info" or "debug". +// +// Parameters: +// - format: The format string (as in fmt.Printf). +// - args: The arguments for the format string. func (l *Logger) Infof(format string, args ...interface{}) { l.logInfo.Printf(format, args...) } -// Debugf logs a debug message using Printf formatting. -// These messages are only output when debug level logging is enabled -// and are intended for detailed troubleshooting information. +// Debugf logs a message at the DEBUG level using Printf style formatting. +// Equivalent to calling l.Debug(format, args...). +// Output is directed to stdout only if the configured log level is "debug". +// +// Parameters: +// - format: The format string (as in fmt.Printf). +// - args: The arguments for the format string. func (l *Logger) Debugf(format string, args ...interface{}) { l.logDebug.Printf(format, args...) } -// Errorf logs an error message using Printf formatting. -// These messages indicate problems that need attention and are -// always written to stderr regardless of the log level. +// Errorf logs a message at the ERROR level using Printf style formatting. +// Equivalent to calling l.Error(format, args...). +// Output is always directed to stderr, regardless of the configured log level. +// +// Parameters: +// - format: The format string (as in fmt.Printf). +// - args: The arguments for the format string. func (l *Logger) Errorf(format string, args ...interface{}) { l.logError.Printf(format, args...) } -// handleError writes an error message to both the HTTP response and the error log. -// It ensures consistent error handling across the middleware by logging the error -// and sending an appropriate HTTP response to the client. +// handleError logs an error message using the provided logger and sends an HTTP error +// response to the client with the specified message and status code. +// +// Parameters: +// - w: The http.ResponseWriter to send the error response to. +// - message: The error message string. +// - code: The HTTP status code for the response. +// - logger: The Logger instance to use for logging the error. func handleError(w http.ResponseWriter, message string, code int, logger *Logger) { logger.Error(message) http.Error(w, message, code)