mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Update documentation to the higher standards.
This commit is contained in:
+10
-2
@@ -2,8 +2,16 @@ package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
|
||||
// It stops when a value is received on the stop channel.
|
||||
// autoCleanupRoutine periodically calls the provided cleanup function.
|
||||
// It starts a ticker with the given interval and executes the cleanup function
|
||||
// on each tick. The routine stops gracefully when a signal is received on the
|
||||
// stop channel. This is typically used for background cleanup tasks like
|
||||
// expiring cache entries.
|
||||
//
|
||||
// Parameters:
|
||||
// - interval: The time duration between cleanup calls.
|
||||
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
|
||||
// - cleanup: The function to call periodically for cleanup tasks.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -46,7 +46,9 @@ type Cache struct {
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance that is ready for use.
|
||||
// NewCache creates a new empty cache instance with default settings.
|
||||
// It initializes the internal maps and list, sets the default maximum size,
|
||||
// and starts the automatic cleanup goroutine.
|
||||
func NewCache() *Cache {
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
@@ -60,8 +62,12 @@ func NewCache() *Cache {
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// It moves the item to the most recently used position.
|
||||
// Set adds or updates an item in the cache with the specified key, value, and expiration duration.
|
||||
// If the key already exists, its value and expiration time are updated, and it's moved
|
||||
// to the most recently used position in the LRU list.
|
||||
// If the key does not exist and the cache is full, the least recently used item is evicted
|
||||
// before adding the new item.
|
||||
// The expiration duration is relative to the time Set is called.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -95,8 +101,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.elems[key] = elem
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Moving the accessed item to the most recently used position.
|
||||
// Get retrieves an item from the cache by its key.
|
||||
// If the item exists and has not expired, its value and true are returned.
|
||||
// Accessing an item moves it to the most recently used position in the LRU list.
|
||||
// If the item does not exist or has expired, nil and false are returned, and the
|
||||
// expired item is removed from the cache.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -120,7 +129,9 @@ func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache.
|
||||
// Delete removes an item from the cache by its key.
|
||||
// If the key exists, the corresponding item is removed from the cache storage
|
||||
// and the LRU list.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -128,8 +139,10 @@ func (c *Cache) Delete(key string) {
|
||||
c.removeItem(key)
|
||||
}
|
||||
|
||||
// Cleanup removes all expired items from the cache. This should be called periodically
|
||||
// to prevent memory bloat from expired entries.
|
||||
// Cleanup iterates through the cache and removes all items that have expired.
|
||||
// An item is considered expired if the current time is after its ExpiresAt timestamp.
|
||||
// This method is called automatically by the auto-cleanup goroutine, but can also
|
||||
// be called manually.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -143,7 +156,11 @@ func (c *Cache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item from the cache.
|
||||
// evictOldest removes the least recently used (oldest) item from the cache.
|
||||
// It first attempts to find and remove an expired item from the front of the LRU list.
|
||||
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
|
||||
// This method is called internally by Set when the cache reaches its maximum size.
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
@@ -167,7 +184,9 @@ func (c *Cache) evictOldest() {
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item from both the cache and the LRU tracking structures.
|
||||
// removeItem removes an item specified by the key from the cache's internal storage (items map)
|
||||
// and its corresponding entry from the LRU list (order list and elems map).
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
@@ -176,12 +195,15 @@ func (c *Cache) removeItem(key string) {
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
|
||||
// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method
|
||||
// at the interval specified by c.autoCleanupInterval.
|
||||
// It uses the autoCleanupRoutine helper function.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close terminates the auto cleanup goroutine.
|
||||
// Close stops the automatic cleanup goroutine associated with this cache instance.
|
||||
// It should be called when the cache is no longer needed to prevent resource leaks.
|
||||
func (c *Cache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+129
-38
@@ -15,10 +15,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
// the authentication request.
|
||||
// generateNonce creates a cryptographically secure random string suitable for use as an OIDC nonce.
|
||||
// The nonce is used during the authentication flow to mitigate replay attacks by associating
|
||||
// the ID token with the specific authentication request.
|
||||
// It generates 32 random bytes and encodes them using base64 URL encoding.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (nonce).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -28,9 +32,13 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier creates a cryptographically secure random string
|
||||
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
|
||||
// characters long, per the PKCE spec (RFC 7636).
|
||||
// generateCodeVerifier creates a cryptographically secure random string suitable for use as a PKCE code verifier.
|
||||
// According to RFC 7636, the verifier should be a high-entropy string between 43 and 128 characters long.
|
||||
// This function generates 32 random bytes, resulting in a 43-character base64 URL encoded string.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (code verifier).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Using 32 bytes (256 bits) will produce a 43 character base64url string
|
||||
verifierBytes := make([]byte, 32)
|
||||
@@ -41,8 +49,15 @@ func generateCodeVerifier() (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
|
||||
}
|
||||
|
||||
// deriveCodeChallenge creates a code challenge from a code verifier
|
||||
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
|
||||
// deriveCodeChallenge computes the PKCE code challenge from a given code verifier.
|
||||
// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding)
|
||||
// as defined in RFC 7636.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The high-entropy string generated by generateCodeVerifier.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 URL encoded SHA-256 hash of the code verifier (code challenge).
|
||||
func deriveCodeChallenge(codeVerifier string) string {
|
||||
// Calculate SHA-256 hash of the code verifier
|
||||
hasher := sha256.New()
|
||||
@@ -73,14 +88,22 @@ type TokenResponse struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
|
||||
// It supports both authorization code and refresh token grant types.
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
|
||||
// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens)
|
||||
// and the "refresh_token" grant type (using a refresh token to obtain new tokens).
|
||||
// It includes necessary parameters like client credentials and handles PKCE verification if applicable.
|
||||
// The function follows redirects and handles potential errors during the exchange.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
// - codeVerifier: Optional PKCE code verifier for authorization code grant
|
||||
// - ctx: The context for the outgoing HTTP request.
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token").
|
||||
// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant).
|
||||
// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant).
|
||||
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens (ID, access, refresh).
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant).
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
@@ -140,8 +163,16 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
|
||||
// This is used to refresh access tokens before they expire.
|
||||
// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh)
|
||||
// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the
|
||||
// "refresh_token" grant type.
|
||||
//
|
||||
// Parameters:
|
||||
// - refreshToken: The refresh token previously obtained during authentication or a prior refresh.
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the newly obtained tokens.
|
||||
// - An error if the refresh operation fails.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||
@@ -153,8 +184,17 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// extractClaims parses a JWT token and extracts its claims.
|
||||
// It handles base64url decoding and JSON parsing of the token payload.
|
||||
// extractClaims decodes the payload (claims set) part of a JWT string.
|
||||
// It splits the JWT into its three parts, base64 URL decodes the second part (payload),
|
||||
// and unmarshals the resulting JSON into a map.
|
||||
// Note: This function does *not* validate the token's signature or claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A map representing the JSON claims extracted from the token payload.
|
||||
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -182,21 +222,36 @@ type TokenCache struct {
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache instance.
|
||||
// NewTokenCache creates and initializes a new TokenCache.
|
||||
// It internally creates a new generic Cache instance for storage.
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a token's claims in the cache with an expiration time.
|
||||
// Set stores the claims associated with a specific token string in the cache.
|
||||
// It prefixes the token string to avoid potential collisions with other cache types
|
||||
// and sets the provided expiration duration.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string (used as the key).
|
||||
// - claims: The map of claims associated with the token.
|
||||
// - expiration: The duration for which the cache entry should be valid.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token's claims from the cache.
|
||||
// Returns the claims and a boolean indicating if the token was found.
|
||||
// Get retrieves the cached claims for a given token string.
|
||||
// It prefixes the token string before querying the underlying cache.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to look up.
|
||||
//
|
||||
// Returns:
|
||||
// - The cached claims map if found and valid.
|
||||
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
@@ -207,20 +262,34 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache.
|
||||
// Delete removes the cached entry for a specific token string.
|
||||
// It prefixes the token string before calling the underlying cache's Delete method.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to remove from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the cache.
|
||||
// Cleanup triggers the cleanup process for the underlying generic cache,
|
||||
// removing expired token entries.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
|
||||
// The code verifier is only included in the token request if PKCE is enabled.
|
||||
// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically
|
||||
// for the "authorization_code" grant type. It handles the conditional inclusion of the
|
||||
// PKCE code verifier based on the middleware's configuration (t.enablePKCE).
|
||||
//
|
||||
// Parameters:
|
||||
// - code: The authorization code received from the OIDC provider.
|
||||
// - redirectURL: The redirect URI used in the initial authorization request.
|
||||
// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens.
|
||||
// - An error if the code exchange fails.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -237,8 +306,15 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, code
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings.
|
||||
// Used for efficient lookups in allowed domains and roles.
|
||||
// createStringMap converts a slice of strings into a map[string]struct{} (a set).
|
||||
// This is useful for creating efficient lookups (O(1) average time complexity)
|
||||
// for checking the presence of items like allowed domains, roles, or groups.
|
||||
//
|
||||
// Parameters:
|
||||
// - keys: A slice of strings to be added to the set.
|
||||
//
|
||||
// Returns:
|
||||
// - A map where the keys are the strings from the input slice and the values are empty structs.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
@@ -247,9 +323,17 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout manages the OIDC logout process.
|
||||
// It clears the session and redirects either to the OIDC provider's
|
||||
// end session endpoint (if available) or to the configured post-logout URL.
|
||||
// handleLogout processes requests to the configured logout path.
|
||||
// It performs the following steps:
|
||||
// 1. Retrieves the current user session.
|
||||
// 2. Gets the access token (ID token hint) from the session.
|
||||
// 3. Clears all authentication-related data from the session cookies.
|
||||
// 4. Determines the final post-logout redirect URI.
|
||||
// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available,
|
||||
// it builds the OIDC logout URL and redirects the user agent to the provider for logout.
|
||||
// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI.
|
||||
//
|
||||
// It handles potential errors during session retrieval or clearing.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -291,11 +375,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
|
||||
// BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's
|
||||
// end_session_endpoint, including the required id_token_hint and optional
|
||||
// post_logout_redirect_uri parameters as query arguments.
|
||||
//
|
||||
// Parameters:
|
||||
// - endSessionURL: The OIDC provider's end session endpoint
|
||||
// - idToken: The ID token to be invalidated
|
||||
// - postLogoutRedirectURI: Where to redirect after logout completes
|
||||
// - endSessionURL: The URL of the OIDC provider's end session endpoint.
|
||||
// - idToken: The ID token previously issued to the user (used as id_token_hint).
|
||||
// - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout.
|
||||
//
|
||||
// Returns:
|
||||
// - The fully constructed logout URL string.
|
||||
// - An error if the provided endSessionURL is invalid.
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -45,6 +45,21 @@ type JWKCacheInterface interface {
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// GetJWKS retrieves the JSON Web Key Set (JWKS) from the cache or fetches it from the provider.
|
||||
// It first checks if a valid, non-expired JWKS is present in the cache. If so, it returns the cached version.
|
||||
// Otherwise, it attempts to fetch the JWKS from the specified jwksURL using the provided httpClient.
|
||||
// If the fetch is successful, the JWKS is stored in the cache with an expiration time based on CacheLifetime
|
||||
// (defaulting to 1 hour if not set) and returned.
|
||||
// This method uses double-checked locking to minimize contention when the cache needs refreshing.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request if fetching is required.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for fetching the JWKS.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the JWKSet containing the keys.
|
||||
// - An error if fetching fails or the response cannot be decoded.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
@@ -74,6 +89,8 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// Cleanup removes the cached JWKS if it has expired.
|
||||
// This is intended to be called periodically to ensure stale JWKS data is cleared.
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -84,6 +101,17 @@ func (c *JWKCache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set (JWKS) from the specified URL.
|
||||
// It uses the provided context and HTTP client to make the request.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for the request.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the fetched JWKSet.
|
||||
// - An error if the request fails, the status code is not OK, or the response body cannot be decoded.
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
@@ -109,6 +137,16 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK (JSON Web Key) object into PEM (Privacy-Enhanced Mail) format.
|
||||
// It selects the appropriate conversion function based on the JWK's key type ("kty").
|
||||
// Currently supports "RSA" and "EC" key types.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the JWK object to convert.
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the public key in PEM format.
|
||||
// - An error if the key type is unsupported or conversion fails.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -124,6 +162,17 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK into PEM format.
|
||||
// It decodes the modulus (n) and exponent (e) from base64 URL encoding,
|
||||
// constructs an rsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the RSA JWK object (must have "kty": "RSA").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the RSA public key in PEM format.
|
||||
// - An error if decoding parameters fails or key marshaling fails.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -155,6 +204,18 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JWK into PEM format.
|
||||
// It decodes the X and Y coordinates from base64 URL encoding, determines the
|
||||
// elliptic curve based on the "crv" parameter (P-256, P-384, P-521),
|
||||
// constructs an ecdsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the EC JWK object (must have "kty": "EC").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the EC public key in PEM format.
|
||||
// - An error if decoding parameters fails, the curve is unsupported, or key marshaling fails.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -20,6 +20,10 @@ var (
|
||||
replayCache = make(map[string]time.Time)
|
||||
)
|
||||
|
||||
// cleanupReplayCache iterates through the replay cache and removes entries
|
||||
// whose expiration time is before the current time. This function should be
|
||||
// called periodically to prevent the cache from growing indefinitely.
|
||||
// It acquires a mutex to ensure thread safety during cleanup.
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
for token, expiry := range replayCache {
|
||||
@@ -48,7 +52,18 @@ type JWT struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
|
||||
// It splits the token string by '.', decodes each part using base64 URL decoding,
|
||||
// and unmarshals the header and claims JSON into maps. The raw signature bytes are stored.
|
||||
// It performs basic format validation (expecting 3 parts).
|
||||
// Note: This function does *not* validate the signature or the claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a JWT struct containing the decoded parts.
|
||||
// - An error if the token format is invalid or decoding/unmarshaling fails.
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -84,8 +99,24 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify performs standard claim validation on the JWT according to RFC 7519.
|
||||
// It checks the following:
|
||||
// - Algorithm ('alg') is supported.
|
||||
// - Issuer ('iss') matches the expected issuerURL.
|
||||
// - Audience ('aud') contains the expected clientID.
|
||||
// - Expiration time ('exp') is in the future (within tolerance).
|
||||
// - Issued at time ('iat') is in the past (within tolerance).
|
||||
// - Not before time ('nbf'), if present, is in the past (within tolerance).
|
||||
// - Subject ('sub') claim exists and is not empty.
|
||||
// - JWT ID ('jti'), if present, is checked against a replay cache to prevent token reuse.
|
||||
//
|
||||
// Parameters:
|
||||
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
|
||||
// - clientID: The expected audience value (the client ID of this application).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if all standard claims are valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
@@ -175,6 +206,16 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the expected audience is present in the token's 'aud' claim.
|
||||
// The 'aud' claim can be a single string or an array of strings.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenAudience: The 'aud' claim value extracted from the token (can be string or []interface{}).
|
||||
// - expectedAudience: The audience value expected for this application (client ID).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the expected audience is found.
|
||||
// - An error if the claim type is invalid or the expected audience is not present.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
@@ -198,6 +239,15 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer checks if the token's 'iss' claim matches the expected issuer URL.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenIssuer: The 'iss' claim value from the token.
|
||||
// - expectedIssuer: The expected issuer URL configured for the OIDC provider.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the issuers match.
|
||||
// - An error if the issuers do not match.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
@@ -205,57 +255,76 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyTimeConstraint is a generic function to verify time-based claims
|
||||
// verifyTimeConstraint checks time-based claims ('exp', 'iat', 'nbf') against the current time,
|
||||
// allowing for configurable clock skew. It uses different tolerances for past and future checks.
|
||||
//
|
||||
// Parameters:
|
||||
// - unixTime: The timestamp value from the claim (as a float64 Unix time).
|
||||
// - claimName: The name of the claim being verified ("exp", "iat", "nbf").
|
||||
// - future: A boolean indicating the direction of the check (true for 'exp', false for 'iat'/'nbf').
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the time constraint is met within the allowed tolerance.
|
||||
// - An error describing the failure (e.g., "token has expired", "token used before issued").
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
now := time.Now() // Use current time without truncation
|
||||
|
||||
var skewedNow time.Time
|
||||
if future {
|
||||
// For expiration (future=true), add skew to now (making now later)
|
||||
// Use the larger tolerance for future checks (exp)
|
||||
skewedNow = now.Add(ClockSkewToleranceFuture)
|
||||
} else {
|
||||
// For iat/nbf (future=false), subtract skew from now (making now earlier)
|
||||
// Use the smaller, specific tolerance for past checks (iat, nbf)
|
||||
skewedNow = now.Add(-ClockSkewTolerancePast) // Subtract the past tolerance
|
||||
}
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For expiration: if skewedNow (later) is after expiration, token expired
|
||||
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
|
||||
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
|
||||
var reason string
|
||||
if future {
|
||||
reason = "has expired"
|
||||
} else {
|
||||
var err error
|
||||
if future { // 'exp' check
|
||||
// Token is expired if Now is after (ClaimTime + FutureTolerance)
|
||||
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
|
||||
if now.After(allowedExpiry) {
|
||||
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
|
||||
}
|
||||
} else { // 'iat' or 'nbf' check
|
||||
// Token is invalid if Now is before (ClaimTime - PastTolerance)
|
||||
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
|
||||
if now.Before(allowedStart) {
|
||||
reason := "not yet valid"
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
} else {
|
||||
reason = "not yet valid"
|
||||
}
|
||||
err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
|
||||
}
|
||||
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// verifyExpiration checks the 'exp' (Expiration Time) claim.
|
||||
// It calls verifyTimeConstraint with future=true.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks the 'iat' (Issued At) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
// verifyNotBefore checks the 'nbf' (Not Before) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature validates the JWT's signature using the provided public key.
|
||||
// It parses the public key from PEM format, selects the appropriate hashing algorithm
|
||||
// based on the 'alg' parameter (SHA256/384/512), hashes the token's signing input
|
||||
// (header + "." + payload), and then verifies the signature against the hash using
|
||||
// the corresponding RSA (PKCS1v15 or PSS) or ECDSA verification method.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw, complete JWT string.
|
||||
// - publicKeyPEM: The public key corresponding to the private key used for signing, in PEM format.
|
||||
// - alg: The algorithm specified in the JWT header (e.g., "RS256", "ES384").
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the signature is valid.
|
||||
// - An error if the token format is invalid, decoding fails, key parsing fails,
|
||||
// the algorithm is unsupported, or the signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
|
||||
@@ -17,7 +17,13 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// createDefaultHTTPClient creates an HTTP client with optimized settings for OIDC
|
||||
// createDefaultHTTPClient creates a new http.Client with settings optimized for OIDC communication.
|
||||
// It configures the transport with specific timeouts (dial, keepalive, TLS handshake, idle connection),
|
||||
// connection limits (max idle, max per host), enables HTTP/2, and sets a default request timeout.
|
||||
// It also configures redirect handling to follow redirects up to a limit.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the configured http.Client.
|
||||
func createDefaultHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -128,16 +134,20 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
// VerifyToken implements the TokenVerifier interface to verify an OIDC token.
|
||||
// It performs a complete verification process including:
|
||||
// 1. Checking the token cache to avoid redundant verifications
|
||||
// 2. Performing rate limiting and blacklist checks
|
||||
// 3. Parsing the JWT structure
|
||||
// 4. Verifying the JWT signature against the JWKS from the provider
|
||||
// 5. Validating standard JWT claims (iss, aud, exp, etc.)
|
||||
// 6. Caching the verified token for future requests
|
||||
// VerifyToken implements the TokenVerifier interface. It performs a comprehensive validation of an ID token:
|
||||
// 1. Checks the token cache; returns nil immediately if a valid cached entry exists.
|
||||
// 2. Performs pre-verification checks (rate limiting, blacklist).
|
||||
// 3. Parses the raw token string into a JWT struct.
|
||||
// 4. Verifies the JWT signature and standard claims (iss, aud, exp, iat, nbf, sub) using VerifyJWTSignatureAndClaims.
|
||||
// 5. If verification succeeds, caches the token claims until the token's expiration time.
|
||||
// 6. If verification succeeds and the token has a JTI claim, adds the JTI to the blacklist cache to prevent replay attacks.
|
||||
//
|
||||
// Returns nil if the token is valid, or an error describing the validation failure.
|
||||
// Parameters:
|
||||
// - token: The raw ID token string to verify.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the token is valid according to all checks.
|
||||
// - An error describing the reason for validation failure (e.g., rate limit, blacklisted, parsing error, signature error, claim error).
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
// Check cache first
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
@@ -192,7 +202,16 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// performPreVerificationChecks performs rate limiting and blacklist checks
|
||||
// performPreVerificationChecks executes preliminary checks before attempting full token validation.
|
||||
// It enforces rate limiting using the configured limiter and checks if the raw token string
|
||||
// or its JTI (if extractable) exists in the blacklist cache.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string being verified.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if all pre-verification checks pass.
|
||||
// - An error if the rate limit is exceeded or the token/JTI is blacklisted.
|
||||
func (t *TraefikOidc) performPreVerificationChecks(token string) error {
|
||||
// Enforce rate limiting
|
||||
if !t.limiter.Allow() {
|
||||
@@ -218,7 +237,13 @@ func (t *TraefikOidc) performPreVerificationChecks(token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheVerifiedToken caches a verified token until its expiration time
|
||||
// cacheVerifiedToken adds the claims of a successfully verified token to the token cache.
|
||||
// It calculates the remaining duration until the token's 'exp' claim and uses that
|
||||
// duration for the cache entry's lifetime.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string (used as the cache key).
|
||||
// - claims: The map of claims extracted from the verified token.
|
||||
func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interface{}) {
|
||||
expirationTime := time.Unix(int64(claims["exp"].(float64)), 0)
|
||||
now := time.Now()
|
||||
@@ -226,7 +251,18 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
|
||||
t.tokenCache.Set(token, claims, duration)
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
|
||||
// VerifyJWTSignatureAndClaims implements the JWTVerifier interface. It verifies the signature
|
||||
// of a parsed JWT against the provider's public keys obtained from the JWKS endpoint,
|
||||
// and then validates the standard JWT claims (iss, aud, exp, iat, nbf, sub, jti replay).
|
||||
//
|
||||
// Parameters:
|
||||
// - jwt: A pointer to the parsed JWT struct containing header and claims.
|
||||
// - token: The original raw token string (used for signature verification).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if both the signature and all standard claims are valid.
|
||||
// - An error describing the validation failure (e.g., failed to get JWKS, missing kid/alg,
|
||||
// no matching key, signature verification failed, standard claim validation failed).
|
||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
|
||||
@@ -277,24 +313,28 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new instance of the OIDC middleware.
|
||||
// This is the main entry point for the middleware and is called by Traefik when loading the plugin.
|
||||
// It initializes all components needed for OIDC authentication:
|
||||
// - Session management for storing user state
|
||||
// - Token caching and blacklisting
|
||||
// - JWK caching for signature verification
|
||||
// - Rate limiting to prevent abuse
|
||||
// - Metadata discovery for OIDC provider endpoints
|
||||
// New is the constructor for the TraefikOidc middleware plugin.
|
||||
// It is called by Traefik during plugin initialization. It performs the following steps:
|
||||
// 1. Creates a default configuration if none is provided.
|
||||
// 2. Validates the session encryption key length.
|
||||
// 3. Initializes the logger based on the configured log level.
|
||||
// 4. Sets up the HTTP client (using defaults if none provided in config).
|
||||
// 5. Creates the main TraefikOidc struct, populating fields from the config
|
||||
// (paths, client details, PKCE/HTTPS flags, scopes, rate limiter, caches, allowed lists).
|
||||
// 6. Initializes the SessionManager.
|
||||
// 7. Sets up internal function pointers/interfaces (extractClaimsFunc, initiateAuthenticationFunc, tokenVerifier, jwtVerifier, tokenExchanger).
|
||||
// 8. Adds default excluded URLs.
|
||||
// 9. Starts background goroutines for token cache cleanup and OIDC provider metadata initialization/refresh.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for initialization operations
|
||||
// - next: The next handler in the middleware chain
|
||||
// - config: Configuration options for the middleware
|
||||
// - name: Identifier for this middleware instance
|
||||
// - ctx: The context provided by Traefik for initialization.
|
||||
// - next: The next http.Handler in the Traefik middleware chain.
|
||||
// - config: The plugin configuration provided by the user in Traefik static/dynamic configuration.
|
||||
// - name: The name assigned to this middleware instance by Traefik.
|
||||
//
|
||||
// Returns:
|
||||
// - An http.Handler that implements the middleware
|
||||
// - An error if initialization fails
|
||||
// - An http.Handler (the TraefikOidc instance itself, which implements ServeHTTP).
|
||||
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
if config == nil {
|
||||
config = CreateConfig()
|
||||
@@ -386,7 +426,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// initializeMetadata discovers and initializes the provider metadata
|
||||
// initializeMetadata asynchronously fetches and caches the OIDC provider metadata.
|
||||
// It uses the MetadataCache to retrieve potentially cached data or fetch fresh data
|
||||
// via discoverProviderMetadata. On successful retrieval, it updates the middleware's
|
||||
// endpoint URLs (auth, token, jwks, etc.), starts the periodic metadata refresh goroutine,
|
||||
// and signals completion by closing the initComplete channel. If fetching fails initially,
|
||||
// it logs an error and the middleware might remain uninitialized until a successful refresh.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Debug("Starting provider metadata discovery")
|
||||
|
||||
@@ -412,7 +460,12 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Error("Received nil metadata")
|
||||
}
|
||||
|
||||
// updateMetadataEndpoints updates the middleware with metadata endpoints
|
||||
// updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.)
|
||||
// within the TraefikOidc instance based on the discovered provider metadata.
|
||||
// This is called after successfully fetching or refreshing the metadata.
|
||||
//
|
||||
// Parameters:
|
||||
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
@@ -422,7 +475,13 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
}
|
||||
|
||||
// startMetadataRefresh periodically refreshes the OIDC metadata
|
||||
// startMetadataRefresh starts a background goroutine that periodically attempts to refresh
|
||||
// the OIDC provider metadata by calling GetMetadata on the metadataCache.
|
||||
// It runs on a fixed ticker (currently 1 hour). Successful refreshes update the
|
||||
// middleware's endpoint URLs via updateMetadataEndpoints. Fetch errors are logged.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider, used for subsequent refresh attempts.
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
@@ -442,7 +501,19 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
}
|
||||
}
|
||||
|
||||
// discoverProviderMetadata fetches the OIDC provider metadata
|
||||
// discoverProviderMetadata attempts to fetch the OIDC provider's configuration from its
|
||||
// well-known discovery endpoint (".well-known/openid-configuration").
|
||||
// It implements an exponential backoff retry mechanism in case of transient network errors
|
||||
// or provider unavailability during startup.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
// - httpClient: The HTTP client to use for the request.
|
||||
// - l: The logger instance for recording retries and errors.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the fetched ProviderMetadata struct.
|
||||
// - An error if fetching fails after all retries or if a timeout is exceeded.
|
||||
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
|
||||
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
@@ -481,7 +552,16 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
|
||||
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
|
||||
}
|
||||
|
||||
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
|
||||
// fetchMetadata performs a single attempt to fetch and decode the OIDC provider metadata
|
||||
// from the specified well-known configuration URL.
|
||||
//
|
||||
// Parameters:
|
||||
// - wellKnownURL: The full URL to the ".well-known/openid-configuration" endpoint.
|
||||
// - httpClient: The HTTP client to use for the GET request.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the decoded ProviderMetadata struct.
|
||||
// - An error if the GET request fails, the status code is not 200 OK, or JSON decoding fails.
|
||||
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
|
||||
resp, err := httpClient.Get(wellKnownURL)
|
||||
if err != nil {
|
||||
@@ -504,18 +584,22 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the main handler for the middleware that processes all HTTP requests.
|
||||
// It implements the http.Handler interface and performs the following operations:
|
||||
// 1. Waits for OIDC provider metadata initialization to complete
|
||||
// 2. Checks if the requested URL is in the excluded list (bypassing authentication)
|
||||
// 3. Retrieves or creates a user session
|
||||
// 4. Handles special paths like callback and logout URLs
|
||||
// 5. Verifies authentication status and token validity
|
||||
// 6. Refreshes tokens that are about to expire
|
||||
// 7. Validates user email domains, roles, and groups against configured restrictions
|
||||
// 8. Sets appropriate headers for downstream services
|
||||
// 9. Applies security headers to responses
|
||||
// 10. Forwards the authenticated request to the next handler
|
||||
// ServeHTTP is the main entry point for incoming requests to the middleware.
|
||||
// It orchestrates the OIDC authentication flow:
|
||||
// 1. Waits for initial OIDC metadata discovery to complete (with timeout).
|
||||
// 2. Checks if the request path is excluded from authentication.
|
||||
// 3. Checks if the request is for Server-Sent Events and bypasses if so.
|
||||
// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing.
|
||||
// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout).
|
||||
// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry).
|
||||
// 7. If the token is expired, handles it (initiates re-auth).
|
||||
// 8. If the user is not authenticated, initiates authentication.
|
||||
// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure
|
||||
// by returning 401 for API clients or initiating re-auth for browsers.
|
||||
// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups).
|
||||
// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*)
|
||||
// and adds security headers (X-Frame-Options, etc.) to the response.
|
||||
// 12. Forwards the request to the next handler in the chain.
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
@@ -698,8 +782,17 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// handleExpiredToken manages token expiration by clearing the session
|
||||
// and initiating a new authentication flow.
|
||||
// handleExpiredToken is called when a user's session contains an expired token or
|
||||
// when a token refresh attempt fails for a browser client.
|
||||
// It clears the authentication-related data (tokens, email, authenticated flag) from the session,
|
||||
// saves the cleared session, and then initiates a new authentication flow by calling
|
||||
// defaultInitiateAuthentication, redirecting the user to the OIDC provider.
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The HTTP request.
|
||||
// - session: The user's session data containing the expired token information.
|
||||
// - redirectURL: The callback URL to be used in the new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear authentication data but preserve CSRF state
|
||||
session.SetAuthenticated(false)
|
||||
@@ -717,9 +810,27 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback processes the authentication callback from the OIDC provider.
|
||||
// It validates the callback parameters, exchanges the authorization code for
|
||||
// tokens, verifies the tokens, and establishes the user's session.
|
||||
// handleCallback handles the request received at the OIDC callback URL (redirect_uri).
|
||||
// It performs the following steps:
|
||||
// 1. Retrieves the user session associated with the callback request.
|
||||
// 2. Checks for error parameters returned by the OIDC provider.
|
||||
// 3. Validates the 'state' parameter against the CSRF token stored in the session.
|
||||
// 4. Extracts the authorization 'code' from the query parameters.
|
||||
// 5. Retrieves the PKCE 'code_verifier' from the session (if PKCE is enabled).
|
||||
// 6. Exchanges the authorization code for tokens using the TokenExchanger interface.
|
||||
// 7. Verifies the received ID token's signature and standard claims using VerifyToken.
|
||||
// 8. Extracts claims from the verified ID token.
|
||||
// 9. Verifies the 'nonce' claim against the nonce stored in the session.
|
||||
// 10. Validates the user's email domain against the allowed list.
|
||||
// 11. If all checks pass, updates the session with authentication details (status, email, tokens).
|
||||
// 12. Saves the updated session.
|
||||
// 13. Redirects the user back to their original requested path (stored in session) or the root path.
|
||||
// If any step fails, it sends an appropriate error response using sendErrorResponse.
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The incoming HTTP request to the callback URL.
|
||||
// - redirectURL: The fully qualified callback URL (used in the token exchange request).
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -846,7 +957,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// determineExcludedURL checks if the current request URL is in the excluded list
|
||||
// determineExcludedURL checks if the provided request path matches any of the configured excluded URL prefixes.
|
||||
//
|
||||
// Parameters:
|
||||
// - currentRequest: The path part of the incoming request URL.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the path starts with any of the prefixes in the t.excludedURLs map.
|
||||
// - false otherwise.
|
||||
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
for excludedURL := range t.excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
@@ -858,7 +976,15 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// determineScheme determines the scheme (http or https) of the request
|
||||
// determineScheme determines the request scheme (http or https).
|
||||
// It prioritizes the X-Forwarded-Proto header if present, otherwise checks
|
||||
// the TLS property of the request. Defaults to "http".
|
||||
//
|
||||
// Parameters:
|
||||
// - req: The incoming HTTP request.
|
||||
//
|
||||
// Returns:
|
||||
// - "https" or "http".
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
@@ -869,7 +995,14 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
return "http"
|
||||
}
|
||||
|
||||
// determineHost determines the host of the request
|
||||
// determineHost determines the request host.
|
||||
// It prioritizes the X-Forwarded-Host header if present, otherwise uses the req.Host value.
|
||||
//
|
||||
// Parameters:
|
||||
// - req: The incoming HTTP request.
|
||||
//
|
||||
// Returns:
|
||||
// - The determined host string (e.g., "example.com:8080").
|
||||
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
@@ -877,17 +1010,18 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// isUserAuthenticated checks if the user is authenticated by validating their session and token.
|
||||
// It performs a comprehensive check of the authentication state including:
|
||||
// 1. Verifying the session's authenticated flag
|
||||
// 2. Checking for the presence of an access token
|
||||
// 3. Validating the token's signature and claims
|
||||
// 4. Checking the token's expiration time
|
||||
// isUserAuthenticated checks the authentication status based on the provided session data.
|
||||
// It verifies the session's authenticated flag, the presence and validity of the access token (ID token),
|
||||
// including signature and standard claims (using VerifyJWTSignatureAndClaims). It also checks if the
|
||||
// token is within the configured refreshGracePeriod before its actual expiration.
|
||||
//
|
||||
// Returns three boolean values:
|
||||
// - authenticated: Whether the user is currently authenticated
|
||||
// - needsRefresh: Whether the token is valid but will expire soon (within grace period)
|
||||
// - expired: Whether the token has expired or is otherwise invalid
|
||||
// Parameters:
|
||||
// - session: The SessionData object for the current user.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated (bool): True if the session is marked authenticated and the token is present and valid (signature/claims ok, not expired beyond grace).
|
||||
// - needsRefresh (bool): True if the token is valid but nearing expiration (within refreshGracePeriod) OR if VerifyJWTSignatureAndClaims failed specifically due to expiration (meaning refresh might be possible).
|
||||
// - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims).
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
@@ -945,19 +1079,18 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication initiates the OIDC authentication process.
|
||||
// This function prepares and starts a new authentication flow by:
|
||||
// 1. Generating security tokens (CSRF token and nonce) to prevent attacks
|
||||
// 2. Clearing any existing session data to avoid state conflicts
|
||||
// 3. Storing the original request path to redirect back after authentication
|
||||
// 4. Building the authorization URL with all required OIDC parameters
|
||||
// 5. Redirecting the user to the OIDC provider's authorization endpoint
|
||||
// defaultInitiateAuthentication handles the process of starting an OIDC authentication flow.
|
||||
// It generates necessary security values (CSRF token, nonce, PKCE verifier/challenge if enabled),
|
||||
// clears any potentially stale data from the current session, stores the new security values
|
||||
// and the original request URI in the session, saves the session (setting cookies),
|
||||
// builds the OIDC authorization endpoint URL with required parameters, and finally
|
||||
// redirects the user's browser to that URL.
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer for sending the redirect
|
||||
// - req: The original HTTP request that triggered authentication
|
||||
// - session: The user's session data for storing authentication state
|
||||
// - redirectURL: The callback URL where the OIDC provider will redirect after authentication
|
||||
// - rw: The HTTP response writer used to send the redirect response.
|
||||
// - req: The original incoming HTTP request that requires authentication.
|
||||
// - session: The user's SessionData object (potentially new or cleared).
|
||||
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.NewString()
|
||||
@@ -1007,15 +1140,33 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// verifyToken verifies the token using the token verifier interface.
|
||||
// This function delegates to the configured token verifier implementation,
|
||||
// which by default is the TraefikOidc instance itself (implementing the VerifyToken method).
|
||||
// This design allows for easy mocking in tests and potential future extension.
|
||||
// verifyToken is a wrapper method that calls the VerifyToken method of the configured
|
||||
// TokenVerifier interface (which defaults to the TraefikOidc instance itself).
|
||||
// This primarily exists to facilitate testing and potential future extensions where
|
||||
// token verification logic might be delegated differently.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to verify.
|
||||
//
|
||||
// Returns:
|
||||
// - The result of calling t.tokenVerifier.VerifyToken(token).
|
||||
func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
// buildAuthURL constructs the authentication URL with optional PKCE support
|
||||
// buildAuthURL constructs the OIDC authorization endpoint URL with all necessary query parameters
|
||||
// for initiating the authorization code flow. It includes client_id, response_type, redirect_uri,
|
||||
// state, nonce, and optionally PKCE parameters (code_challenge, code_challenge_method) if enabled
|
||||
// and a challenge is provided. It also includes configured scopes.
|
||||
//
|
||||
// Parameters:
|
||||
// - redirectURL: The callback URL (redirect_uri).
|
||||
// - state: The CSRF token.
|
||||
// - nonce: The OIDC nonce.
|
||||
// - codeChallenge: The PKCE code challenge (can be empty if PKCE is disabled or not used).
|
||||
//
|
||||
// Returns:
|
||||
// - The fully constructed authorization URL string.
|
||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", t.clientID)
|
||||
@@ -1037,7 +1188,16 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
return t.buildURLWithParams(t.authURL, params)
|
||||
}
|
||||
|
||||
// buildURLWithParams ensures a URL is absolute and appends query parameters
|
||||
// buildURLWithParams takes a base URL and query parameters and constructs a full URL string.
|
||||
// If the baseURL is relative (doesn't start with http/https), it prepends the scheme and host
|
||||
// from the configured issuerURL. It then appends the encoded query parameters.
|
||||
//
|
||||
// Parameters:
|
||||
// - baseURL: The base URL (can be absolute or relative to the issuer).
|
||||
// - params: A url.Values map containing the query parameters to append.
|
||||
//
|
||||
// Returns:
|
||||
// - The fully constructed URL string with appended query parameters.
|
||||
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
// Ensure URL is absolute
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
@@ -1054,7 +1214,8 @@ func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) stri
|
||||
return baseURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
// startTokenCleanup starts background goroutines for periodically cleaning up
|
||||
// the token cache, token blacklist cache, and JWK cache using the autoCleanupRoutine helper.
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
|
||||
go func() {
|
||||
@@ -1069,7 +1230,14 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}()
|
||||
}
|
||||
|
||||
// RevokeToken adds the token to the blacklist
|
||||
// RevokeToken handles local revocation of a token.
|
||||
// It removes the token from the validation cache (tokenCache) and adds the raw
|
||||
// token string to the blacklist cache (tokenBlacklist) with a default expiration (24h).
|
||||
// This prevents the token from being validated successfully even if it hasn't expired yet.
|
||||
// Note: This does *not* revoke the token with the OIDC provider.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to revoke locally.
|
||||
func (t *TraefikOidc) RevokeToken(token string) {
|
||||
// Remove from cache
|
||||
t.tokenCache.Delete(token)
|
||||
@@ -1080,7 +1248,17 @@ func (t *TraefikOidc) RevokeToken(token string) {
|
||||
t.tokenBlacklist.Set(token, true, time.Until(expiry))
|
||||
}
|
||||
|
||||
// RevokeTokenWithProvider revokes the token with the provider
|
||||
// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider
|
||||
// using the revocation endpoint specified in the provider metadata or configuration.
|
||||
// It sends a POST request with the token, token_type_hint, client_id, and client_secret.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The token (e.g., refresh token or access token) to revoke.
|
||||
// - tokenType: The type hint for the token being revoked (e.g., "refresh_token").
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the revocation request is successful (provider returns 200 OK).
|
||||
// - An error if the request fails or the provider returns a non-OK status.
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
t.logger.Debugf("Revoking token with provider")
|
||||
|
||||
@@ -1117,7 +1295,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshToken refreshes the user's token, protected by a mutex within the session.
|
||||
// refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens.
|
||||
// It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session.
|
||||
// It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method,
|
||||
// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens,
|
||||
// and saves the session.
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer (needed for saving the updated session).
|
||||
// - req: The HTTP request (needed for saving the updated session).
|
||||
// - session: The user's SessionData object containing the refresh token.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the token refresh was successful and the session was updated.
|
||||
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed.
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
// Lock the mutex specific to this session instance before attempting refresh
|
||||
session.refreshMutex.Lock()
|
||||
@@ -1159,7 +1350,16 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// isAllowedDomain checks if the user's email domain is allowed
|
||||
// isAllowedDomain checks if the domain part of the provided email address is present
|
||||
// in the configured list of allowed domains (t.allowedUserDomains).
|
||||
// If the allowed domains list is empty, all domains are considered allowed.
|
||||
//
|
||||
// Parameters:
|
||||
// - email: The email address to check.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the domain is allowed or if no domain restrictions are configured.
|
||||
// - false if the email format is invalid or the domain is not in the allowed list.
|
||||
func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
if len(t.allowedUserDomains) == 0 {
|
||||
return true // If no domains are specified, all are allowed
|
||||
@@ -1175,7 +1375,18 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts groups and roles from the id_token
|
||||
// extractGroupsAndRoles attempts to extract 'groups' and 'roles' claims from a decoded ID token.
|
||||
// It expects these claims, if present, to be arrays of strings.
|
||||
// It uses the configured extractClaimsFunc (which defaults to the package-level extractClaims)
|
||||
// to get the claims map from the token string.
|
||||
//
|
||||
// Parameters:
|
||||
// - idToken: The raw ID token string.
|
||||
//
|
||||
// Returns:
|
||||
// - A slice of strings containing the groups found in the 'groups' claim.
|
||||
// - A slice of strings containing the roles found in the 'roles' claim.
|
||||
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present but not arrays of strings.
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
@@ -1216,7 +1427,17 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
return groups, roles, nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host and path
|
||||
// buildFullURL constructs an absolute URL string from its components.
|
||||
// If the provided path already starts with "http://" or "https://", it's returned directly.
|
||||
// Otherwise, it combines the scheme, host, and path, ensuring the path starts with a '/'.
|
||||
//
|
||||
// Parameters:
|
||||
// - scheme: The URL scheme ("http" or "https").
|
||||
// - host: The host part of the URL (e.g., "example.com:8080").
|
||||
// - path: The path part of the URL (e.g., "/resource").
|
||||
//
|
||||
// Returns:
|
||||
// - The combined absolute URL string (e.g., "https://example.com:8080/resource").
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
// If the path is already a full URL, return it as-is
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
@@ -1233,21 +1454,35 @@ func buildFullURL(scheme, host, path string) string {
|
||||
|
||||
// --- TokenExchanger Interface Implementation ---
|
||||
|
||||
// ExchangeCodeForToken implements the TokenExchanger interface.
|
||||
// It calls the existing exchangeTokens helper function.
|
||||
// ExchangeCodeForToken provides the implementation for the TokenExchanger interface method.
|
||||
// It directly calls the internal exchangeTokens method, passing through the arguments.
|
||||
// This allows the TraefikOidc struct to act as its own default TokenExchanger, while
|
||||
// still allowing mocking for tests.
|
||||
func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// Note: The original exchangeTokens helper is defined in helpers.go and is already a method on *TraefikOidc
|
||||
return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
||||
}
|
||||
|
||||
// GetNewTokenWithRefreshToken implements the TokenExchanger interface.
|
||||
// It calls the existing getNewTokenWithRefreshToken helper function.
|
||||
// GetNewTokenWithRefreshToken provides the implementation for the TokenExchanger interface method.
|
||||
// It directly calls the internal getNewTokenWithRefreshToken helper method.
|
||||
// This allows the TraefikOidc struct to act as its own default TokenExchanger, while
|
||||
// still allowing mocking for tests.
|
||||
func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
// Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc
|
||||
return t.getNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response, adapting to the client's Accept header.
|
||||
// sendErrorResponse sends an error response to the client, adapting the format based
|
||||
// on the request's Accept header. If the client prefers "application/json", it sends
|
||||
// a JSON object with "error", "error_description", and "status_code" fields.
|
||||
// Otherwise, it sends a basic HTML error page containing the message and a link
|
||||
// back to the application root or the original incoming path (if available from the session).
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer.
|
||||
// - req: The HTTP request (used to check Accept header and potentially get session).
|
||||
// - message: The error message to display/include in the response.
|
||||
// - code: The HTTP status code to set for the response.
|
||||
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
|
||||
|
||||
+301
-5
@@ -376,6 +376,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
setupSession func(*SessionData)
|
||||
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
|
||||
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks
|
||||
requestHeaders map[string]string // Added for setting headers like Accept
|
||||
}{
|
||||
{
|
||||
name: "Excluded URL",
|
||||
@@ -394,7 +395,13 @@ func TestServeHTTP(t *testing.T) {
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ts.token) // Use the valid token generated in Setup
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
@@ -450,14 +457,161 @@ func TestServeHTTP(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Logout URL",
|
||||
requestPath: "/logout", // Assuming logout path is configured or defaulted correctly
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ts.token)
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect after logout
|
||||
expectedBody: "",
|
||||
// No specific session assertion needed for logout redirect itself
|
||||
},
|
||||
{
|
||||
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken())
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
// Simulate failed refresh
|
||||
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
||||
}
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client
|
||||
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
|
||||
},
|
||||
{
|
||||
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken())
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
// Simulate failed refresh
|
||||
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
||||
}
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "text/html", // Browser client
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect for browser client
|
||||
},
|
||||
{
|
||||
name: "Authenticated request with token nearing expiry (needs refresh)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// Create token expiring soon (e.g., 30s, within default 60s grace period)
|
||||
exp := time.Now().Add(30 * time.Second).Unix()
|
||||
iat := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry")
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
if refreshToken != "valid-refresh-token-for-near-expiry" {
|
||||
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
||||
}
|
||||
// Simulate successful refresh
|
||||
newToken := createNewValidToken()
|
||||
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect success after proactive refresh
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Authenticated request with token valid (outside grace period)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// Create token expiring later (e.g., 10 mins, outside default 60s grace period)
|
||||
exp := time.Now().Add(10 * time.Minute).Unix()
|
||||
iat := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
||||
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
// This should NOT be called
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period")
|
||||
return nil, fmt.Errorf("refresh should not have been attempted")
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect success, no refresh needed
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Disallowed Domain (Accept: JSON)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
|
||||
},
|
||||
{
|
||||
name: "Disallowed Domain (Accept: HTML)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "text/html",
|
||||
},
|
||||
expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response
|
||||
expectedBody: "", // Body check is harder for HTML, focus on status and content-type
|
||||
},
|
||||
}
|
||||
|
||||
@@ -468,6 +622,12 @@ func TestServeHTTP(t *testing.T) {
|
||||
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
|
||||
req.Header.Set("X-Forwarded-Host", "testhost.com")
|
||||
req.Host = "testhost.com" // Also set Host header
|
||||
// Set request headers from test case
|
||||
if tc.requestHeaders != nil {
|
||||
for key, value := range tc.requestHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
@@ -527,10 +687,19 @@ func TestServeHTTP(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check response body if expected
|
||||
// Check response body if expected (handle JSON vs HTML)
|
||||
if tc.expectedBody != "" {
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
|
||||
// For JSON, compare directly
|
||||
if strings.Contains(rr.Header().Get("Content-Type"), "application/json") {
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body)
|
||||
}
|
||||
} else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
|
||||
}
|
||||
}
|
||||
// Add more sophisticated HTML body checks if needed
|
||||
}
|
||||
|
||||
// Perform post-request session assertions if defined
|
||||
@@ -2319,3 +2488,130 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances.
|
||||
func TestVerifyTimeConstraint(t *testing.T) {
|
||||
// Define tolerances used in jwt.go (ensure they match)
|
||||
toleranceFuture := 2 * time.Minute
|
||||
tolerancePast := 10 * time.Second
|
||||
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claimTime time.Time
|
||||
claimName string
|
||||
futureCheck bool // true for exp, false for iat/nbf
|
||||
expectError bool
|
||||
}{
|
||||
// Expiration (future=true, tolerance=2min)
|
||||
{
|
||||
name: "EXP: Valid (expires in 1 min)",
|
||||
claimTime: now.Add(1 * time.Minute),
|
||||
claimName: "exp",
|
||||
futureCheck: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "EXP: Expired (expired 3 min ago)",
|
||||
claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance
|
||||
claimName: "exp",
|
||||
futureCheck: true,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "EXP: Valid (expired 1 min ago, within 2min tolerance)",
|
||||
claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance
|
||||
claimName: "exp",
|
||||
futureCheck: true,
|
||||
expectError: false, // Should be allowed due to future tolerance
|
||||
},
|
||||
|
||||
// Issued At (future=false, tolerance=10s)
|
||||
{
|
||||
name: "IAT: Valid (issued 1 min ago)",
|
||||
claimTime: now.Add(-1 * time.Minute),
|
||||
claimName: "iat",
|
||||
futureCheck: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IAT: Invalid (issued 15 sec in future)",
|
||||
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
||||
claimName: "iat",
|
||||
futureCheck: false,
|
||||
expectError: true, // "token used before issued"
|
||||
},
|
||||
{
|
||||
name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)",
|
||||
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
||||
claimName: "iat",
|
||||
futureCheck: false,
|
||||
expectError: false, // Should be allowed due to past tolerance
|
||||
},
|
||||
|
||||
// Not Before (future=false, tolerance=10s)
|
||||
{
|
||||
name: "NBF: Valid (active 1 min ago)",
|
||||
claimTime: now.Add(-1 * time.Minute),
|
||||
claimName: "nbf",
|
||||
futureCheck: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "NBF: Invalid (active in 15 sec)",
|
||||
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
||||
claimName: "nbf",
|
||||
futureCheck: false,
|
||||
expectError: true, // "token not yet valid"
|
||||
},
|
||||
{
|
||||
name: "NBF: Valid (active in 5 sec, within 10s tolerance)",
|
||||
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
||||
claimName: "nbf",
|
||||
futureCheck: false,
|
||||
expectError: false, // Should be allowed due to past tolerance
|
||||
},
|
||||
}
|
||||
|
||||
// Temporarily adjust global tolerances for test consistency, then restore
|
||||
originalFutureTolerance := ClockSkewToleranceFuture
|
||||
originalPastTolerance := ClockSkewTolerancePast
|
||||
ClockSkewToleranceFuture = toleranceFuture
|
||||
ClockSkewTolerancePast = tolerancePast
|
||||
defer func() {
|
||||
ClockSkewToleranceFuture = originalFutureTolerance
|
||||
ClockSkewTolerancePast = originalPastTolerance
|
||||
}()
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Convert claim time to float64 unix timestamp
|
||||
unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9
|
||||
|
||||
var err error
|
||||
// Call the specific verification function which uses verifyTimeConstraint
|
||||
if tc.claimName == "exp" {
|
||||
err = verifyExpiration(unixTime)
|
||||
} else if tc.claimName == "iat" {
|
||||
err = verifyIssuedAt(unixTime)
|
||||
} else if tc.claimName == "nbf" {
|
||||
err = verifyNotBefore(unixTime)
|
||||
} else {
|
||||
t.Fatalf("Unknown claim name in test setup: %s", tc.claimName)
|
||||
}
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now)
|
||||
} else {
|
||||
t.Logf("Got expected error: %v", err) // Log the error for confirmation
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
} // Add missing closing brace for TestVerifyTimeConstraint
|
||||
|
||||
+28
-2
@@ -15,6 +15,8 @@ type MetadataCache struct {
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// NewMetadataCache creates a new MetadataCache instance.
|
||||
// It initializes the cache structure and starts the background cleanup goroutine.
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
@@ -24,7 +26,8 @@ func NewMetadataCache() *MetadataCache {
|
||||
return c
|
||||
}
|
||||
|
||||
// Cleanup removes expired metadata from the cache.
|
||||
// Cleanup removes the cached provider metadata if it has expired.
|
||||
// This is called periodically by the auto-cleanup goroutine.
|
||||
func (c *MetadataCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -35,11 +38,31 @@ func (c *MetadataCache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// isCacheValid checks if the cached metadata is present and has not expired.
|
||||
// Note: This function assumes the read lock is held or it's called from a context
|
||||
// where the lock is already held (like within GetMetadata after locking).
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the metadata from cache or fetches it if expired
|
||||
// GetMetadata retrieves the OIDC provider metadata.
|
||||
// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately.
|
||||
// If the cache is empty or expired, it attempts to fetch the metadata from the provider's
|
||||
// well-known endpoint using discoverProviderMetadata.
|
||||
// If fetching is successful, the new metadata is cached for 1 hour.
|
||||
// If fetching fails but valid metadata exists in the cache (even if expired), the cache expiry
|
||||
// is extended by 5 minutes, and the cached data is returned to prevent thundering herd issues.
|
||||
// If fetching fails and there's no cached data, an error is returned.
|
||||
// It employs double-checked locking for thread safety and performance.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
// - httpClient: The HTTP client to use for fetching metadata.
|
||||
// - logger: The logger instance for recording errors or warnings.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the ProviderMetadata struct.
|
||||
// - An error if metadata cannot be retrieved from cache or fetched from the provider.
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
@@ -76,10 +99,13 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background goroutine that periodically calls Cleanup
|
||||
// to remove expired metadata from the cache.
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup goroutine associated with this metadata cache.
|
||||
func (c *MetadataCache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+186
-42
@@ -16,8 +16,15 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length.
|
||||
// It returns the generated string or an error if random generation fails.
|
||||
// generateSecureRandomString creates a cryptographically secure, hex-encoded random string.
|
||||
// It reads the specified number of bytes from crypto/rand and encodes them as a hexadecimal string.
|
||||
//
|
||||
// Parameters:
|
||||
// - length: The number of random bytes to generate (the resulting hex string will be twice this length).
|
||||
//
|
||||
// Returns:
|
||||
// - A hex-encoded random string.
|
||||
// - An error if reading random bytes fails.
|
||||
func generateSecureRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -56,7 +63,14 @@ const (
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// compressToken compresses a token using gzip and base64 encodes it.
|
||||
// compressToken compresses the input string using gzip and then encodes the result using standard base64 encoding.
|
||||
// If any error occurs during compression, it returns the original uncompressed token as a fallback.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The string to compress.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 encoded, gzipped string, or the original string if compression fails.
|
||||
func compressToken(token string) string {
|
||||
var b bytes.Buffer
|
||||
gz := gzip.NewWriter(&b)
|
||||
@@ -69,7 +83,15 @@ func compressToken(token string) string {
|
||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||
}
|
||||
|
||||
// decompressToken decompresses a base64 encoded gzipped token.
|
||||
// decompressToken decodes a standard base64 encoded string and then decompresses the result using gzip.
|
||||
// If base64 decoding or gzip decompression fails, it returns the original input string as a fallback,
|
||||
// assuming it might not have been compressed.
|
||||
//
|
||||
// Parameters:
|
||||
// - compressed: The base64 encoded, gzipped string.
|
||||
//
|
||||
// Returns:
|
||||
// - The decompressed original string, or the input string if decompression fails.
|
||||
func decompressToken(compressed string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
@@ -140,15 +162,15 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS.
|
||||
// getSessionOptions returns a sessions.Options struct configured with security best practices.
|
||||
// It sets HttpOnly to true, Secure based on the request scheme or forceHTTPS setting,
|
||||
// SameSite to LaxMode, MaxAge to the absoluteSessionTimeout, and Path to "/".
|
||||
//
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
// - Using SameSite=Lax for CSRF protection
|
||||
// - Set with appropriate timeout and path settings
|
||||
// Parameters:
|
||||
// - isSecure: A boolean indicating if the current request context is secure (HTTPS).
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a configured sessions.Options struct.
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
@@ -210,11 +232,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves all session chunks for a given token type.
|
||||
// getTokenChunkSessions retrieves all cookie chunks associated with a large token (access or refresh).
|
||||
// It iteratively attempts to load cookies named "{baseName}_0", "{baseName}_1", etc., until
|
||||
// a cookie is not found or returns an error. The loaded sessions are stored in the provided chunks map.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request
|
||||
// - baseName: The base name for the token's session cookies
|
||||
// - chunks: Map to store the chunks in
|
||||
// - r: The incoming HTTP request containing the cookies.
|
||||
// - baseName: The base name of the cookie (e.g., accessTokenCookie).
|
||||
// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks) to populate with the found session chunks.
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
@@ -258,10 +283,16 @@ type SessionData struct {
|
||||
refreshMutex sync.Mutex
|
||||
}
|
||||
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
// It saves the main session, token sessions, and any token chunks,
|
||||
// applying appropriate security options to each cookie. All cookies
|
||||
// are saved with consistent security settings based on the request scheme.
|
||||
// Save persists all parts of the session (main, access token, refresh token, and any chunks)
|
||||
// back to the client as cookies in the HTTP response. It applies secure cookie options
|
||||
// obtained via getSessionOptions based on the request's security context.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The original HTTP request (used to determine security context for cookie options).
|
||||
// - w: The HTTP response writer to which the Set-Cookie headers will be added.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if saving any of the session components fails.
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
@@ -305,7 +336,19 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
// Clear removes all session data associated with this SessionData instance.
|
||||
// It clears the values map of the main, access, and refresh sessions, sets their MaxAge to -1
|
||||
// to expire the cookies immediately, and clears any associated token chunk cookies.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired sessions to send the
|
||||
// expiring Set-Cookie headers. Finally, it clears internal fields and returns the SessionData
|
||||
// object to the pool.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request (required by the underlying session store).
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if saving the expired sessions fails (only if w is not nil).
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions.
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
@@ -340,7 +383,12 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
// clearTokenChunks iterates through a map of session chunks, clears their values,
|
||||
// and sets their MaxAge to -1 to expire them. This is used internally by Clear.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request (required by the underlying session store, though not directly used here).
|
||||
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
@@ -350,7 +398,12 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
// GetAuthenticated checks if the session is marked as authenticated and has not exceeded
|
||||
// the absolute session timeout.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the "authenticated" flag is set to true and the session creation time is within the allowed timeout.
|
||||
// - false otherwise.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
if !auth {
|
||||
@@ -365,8 +418,15 @@ func (sd *SessionData) GetAuthenticated() bool {
|
||||
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
|
||||
}
|
||||
|
||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||
// Returns an error if generating a new session ID fails.
|
||||
// SetAuthenticated sets the authentication status of the session.
|
||||
// If setting to true, it generates a new secure session ID for the main session
|
||||
// to prevent session fixation attacks and records the current time as the creation time.
|
||||
//
|
||||
// Parameters:
|
||||
// - value: The boolean authentication status (true for authenticated, false otherwise).
|
||||
//
|
||||
// Returns:
|
||||
// - An error if generating a new session ID fails when setting value to true.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
if value {
|
||||
id, err := generateSecureRandomString(32)
|
||||
@@ -380,7 +440,12 @@ func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
// GetAccessToken retrieves the access token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed access token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -414,7 +479,14 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// SetAccessToken stores the access token in the session.
|
||||
// SetAccessToken stores the provided access token in the session.
|
||||
// It first expires any existing access token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary access token session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_a_0, _oidc_raczylo_a_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The access token string to store.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
@@ -444,7 +516,12 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
// GetRefreshToken retrieves the refresh token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed refresh token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -478,7 +555,14 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
// SetRefreshToken stores the provided refresh token in the session.
|
||||
// It first expires any existing refresh token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary refresh token session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_r_0, _oidc_raczylo_r_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The refresh token string to store.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
@@ -508,7 +592,13 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies.
|
||||
// expireAccessTokenChunks finds all existing access token chunk cookies (_oidc_raczylo_a_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new access token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
@@ -526,7 +616,13 @@ func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
|
||||
// expireRefreshTokenChunks finds all existing refresh token chunk cookies (_oidc_raczylo_r_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new refresh token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
@@ -544,7 +640,15 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
// splitIntoChunks divides a string `s` into a slice of strings, where each element
|
||||
// has a maximum length of `chunkSize`.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to split.
|
||||
// - chunkSize: The maximum size of each chunk.
|
||||
//
|
||||
// Returns:
|
||||
// - A slice of strings representing the chunks.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
@@ -559,57 +663,97 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
// GetCSRF retrieves the Cross-Site Request Forgery (CSRF) token stored in the main session.
|
||||
//
|
||||
// Returns:
|
||||
// - The CSRF token string, or an empty string if not set.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
// SetCSRF stores the provided CSRF token string in the main session.
|
||||
// This token is typically generated at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The CSRF token to store.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
// GetNonce retrieves the OIDC nonce value stored in the main session.
|
||||
// The nonce is used to associate an ID token with the specific authentication request.
|
||||
//
|
||||
// Returns:
|
||||
// - The nonce string, or an empty string if not set.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
// SetNonce stores the provided OIDC nonce string in the main session.
|
||||
// This nonce is typically generated at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - nonce: The nonce string to store.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetCodeVerifier retrieves the PKCE code verifier from the session.
|
||||
// GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier
|
||||
// stored in the main session. This is only relevant if PKCE is enabled.
|
||||
//
|
||||
// Returns:
|
||||
// - The code verifier string, or an empty string if not set or PKCE is disabled.
|
||||
func (sd *SessionData) GetCodeVerifier() string {
|
||||
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||
return codeVerifier
|
||||
}
|
||||
|
||||
// SetCodeVerifier stores the PKCE code verifier in the session.
|
||||
// SetCodeVerifier stores the provided PKCE code verifier string in the main session.
|
||||
// This is typically called at the start of the authentication flow if PKCE is enabled.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The PKCE code verifier string to store.
|
||||
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
// GetEmail retrieves the authenticated user's email address stored in the main session.
|
||||
// This is typically extracted from the ID token claims after successful authentication.
|
||||
//
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail stores the user's email address in the session.
|
||||
// SetEmail stores the provided user email address string in the main session.
|
||||
// This is typically called after successful authentication and claim extraction.
|
||||
//
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
|
||||
// GetIncomingPath retrieves the original request URI (including query parameters)
|
||||
// that the user was trying to access before being redirected for authentication.
|
||||
// This is stored in the main session to allow redirection back after successful login.
|
||||
//
|
||||
// Returns:
|
||||
// - The original request URI string, or an empty string if not set.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath stores the original request path that triggered the authentication flow.
|
||||
// SetIncomingPath stores the original request URI (path and query parameters)
|
||||
// in the main session. This is typically called at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - path: The original request URI string (e.g., "/protected/resource?id=123").
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
+91
-33
@@ -114,6 +114,14 @@ const (
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
//
|
||||
// CreateConfig initializes a new Config struct with default values for optional fields.
|
||||
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
|
||||
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
|
||||
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a new Config struct with default settings applied.
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
@@ -127,9 +135,14 @@ func CreateConfig() *Config {
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate performs validation checks on the Config.
|
||||
// It ensures all required fields are set and have valid values.
|
||||
// Returns an error if any validation check fails.
|
||||
// Validate checks the configuration settings for validity.
|
||||
// It ensures that required fields (ProviderURL, CallbackURL, ClientID, ClientSecret, SessionEncryptionKey)
|
||||
// are present and that URLs are well-formed (HTTPS where required). It also validates
|
||||
// the session key length, log level, rate limit, and refresh grace period.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the configuration is valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (c *Config) Validate() error {
|
||||
// Validate provider URL
|
||||
if c.ProviderURL == "" {
|
||||
@@ -211,13 +224,26 @@ func (c *Config) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidSecureURL checks if the provided string is a valid HTTPS URL
|
||||
// isValidSecureURL checks if a given string represents a valid, absolute HTTPS URL.
|
||||
// It uses url.Parse and checks for a nil error, an "https" scheme, and a non-empty host.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The URL string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the string is a valid HTTPS URL, false otherwise.
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level is valid
|
||||
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
|
||||
//
|
||||
// Parameters:
|
||||
// - level: The log level string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the log level is valid, false otherwise.
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
@@ -234,14 +260,20 @@ type Logger struct {
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger with the specified log level.
|
||||
// The log level determines which messages are output:
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
// NewLogger creates and configures a new Logger instance based on the provided log level.
|
||||
// It initializes loggers for ERROR (stderr), INFO (stdout), and DEBUG (stdout) levels,
|
||||
// enabling output based on the specified level:
|
||||
// - "error": Only ERROR messages are output.
|
||||
// - "info": INFO and ERROR messages are output.
|
||||
// - "debug": DEBUG, INFO, and ERROR messages are output.
|
||||
//
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
// If an invalid level is provided, it defaults to behavior similar to "error".
|
||||
//
|
||||
// Parameters:
|
||||
// - logLevel: The desired logging level ("debug", "info", or "error").
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the configured Logger instance.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
@@ -263,51 +295,77 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an informational message.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
// Info logs a message at the INFO level using Printf style formatting.
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
// Debug logs a message at the DEBUG level using Printf style formatting.
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
// Error logs a message at the ERROR level using Printf style formatting.
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an informational message using Printf formatting.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
// Infof logs a message at the INFO level using Printf style formatting.
|
||||
// Equivalent to calling l.Info(format, args...).
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message using Printf formatting.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
// Debugf logs a message at the DEBUG level using Printf style formatting.
|
||||
// Equivalent to calling l.Debug(format, args...).
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message using Printf formatting.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
// Errorf logs a message at the ERROR level using Printf style formatting.
|
||||
// Equivalent to calling l.Error(format, args...).
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to both the HTTP response and the error log.
|
||||
// It ensures consistent error handling across the middleware by logging the error
|
||||
// and sending an appropriate HTTP response to the client.
|
||||
// handleError logs an error message using the provided logger and sends an HTTP error
|
||||
// response to the client with the specified message and status code.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The http.ResponseWriter to send the error response to.
|
||||
// - message: The error message string.
|
||||
// - code: The HTTP status code for the response.
|
||||
// - logger: The Logger instance to use for logging the error.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
Reference in New Issue
Block a user