diff --git a/cache.go b/cache.go index 759ec1a..d35696f 100644 --- a/cache.go +++ b/cache.go @@ -5,26 +5,41 @@ import ( "time" ) -// CacheItem represents an item in the cache +// CacheItem represents an item stored in the cache with its associated metadata. type CacheItem struct { - Value interface{} + // Value is the cached data of any type + Value interface{} + + // ExpiresAt is the timestamp when this item should be considered expired + // and removed from the cache during cleanup operations ExpiresAt time.Time } -// Cache is a simple in-memory cache +// Cache provides a thread-safe in-memory caching mechanism with expiration support. +// It uses a read-write mutex to ensure safe concurrent access to the cached items. type Cache struct { + // items stores the cached data with string keys items map[string]CacheItem + + // mutex protects concurrent access to the items map + // Use RLock/RUnlock for reads and Lock/Unlock for writes mutex sync.RWMutex } -// NewCache creates a new Cache +// NewCache creates a new empty cache instance. +// The cache is immediately ready for use and is thread-safe. func NewCache() *Cache { return &Cache{ items: make(map[string]CacheItem), } } -// Set adds an item to the cache +// Set adds or updates an item in the cache with the specified expiration duration. +// Parameters: +// - key: Unique identifier for the cached item +// - value: The data to cache (can be of any type) +// - expiration: How long the item should remain in the cache +// Thread-safe: Uses write locking to ensure safe concurrent access. func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() @@ -34,7 +49,13 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { } } -// Get retrieves an item from the cache +// Get retrieves an item from the cache if it exists and hasn't expired. +// Parameters: +// - key: The identifier of the item to retrieve +// Returns: +// - value: The cached data (nil if not found or expired) +// - found: true if the item was found and is valid, false otherwise +// Thread-safe: Uses read locking to ensure safe concurrent access. func (c *Cache) Get(key string) (interface{}, bool) { c.mutex.RLock() defer c.mutex.RUnlock() @@ -49,14 +70,19 @@ func (c *Cache) Get(key string) (interface{}, bool) { return item.Value, true } -// Delete removes an item from the cache +// Delete removes an item from the cache if it exists. +// If the item doesn't exist, this operation is a no-op. +// Thread-safe: Uses write locking to ensure safe concurrent access. func (c *Cache) Delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() delete(c.items, key) } -// Cleanup removes expired items from the cache +// Cleanup removes all expired items from the cache. +// This should be called periodically to prevent memory leaks from +// expired items that haven't been accessed (and thus not removed during Get operations). +// Thread-safe: Uses write locking to ensure safe concurrent access. func (c *Cache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/helpers.go b/helpers.go index cd1e5a4..8825745 100644 --- a/helpers.go +++ b/helpers.go @@ -16,6 +16,13 @@ import ( "github.com/gorilla/sessions" ) +// newSessionOptions creates secure session cookie options. +// Parameters: +// - isSecure: Whether to set the Secure flag on cookies +// Returns session options configured for security with: +// - HttpOnly flag to prevent JavaScript access +// - SameSite=Lax for CSRF protection +// - Appropriate timeout and path settings func newSessionOptions(isSecure bool) *sessions.Options { return &sessions.Options{ HttpOnly: true, @@ -26,7 +33,10 @@ func newSessionOptions(isSecure bool) *sessions.Options { } } -// generateNonce generates a random nonce +// generateNonce creates a cryptographically secure random nonce +// for use in the OIDC authentication flow. The nonce is used to +// prevent replay attacks by ensuring the token received matches +// the authentication request. func generateNonce() (string, error) { nonceBytes := make([]byte, 32) _, err := rand.Read(nonceBytes) @@ -36,7 +46,33 @@ func generateNonce() (string, error) { return base64.URLEncoding.EncodeToString(nonceBytes), nil } -// exchangeTokens exchanges a code or refresh token for tokens +// TokenResponse represents the response from the OIDC token endpoint. +// It contains the various tokens and metadata returned after successful +// code exchange or token refresh operations. +type TokenResponse struct { + // IDToken is the OIDC ID token containing user claims + IDToken string `json:"id_token"` + + // AccessToken is the OAuth 2.0 access token for API access + AccessToken string `json:"access_token"` + + // RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens + RefreshToken string `json:"refresh_token"` + + // ExpiresIn is the lifetime in seconds of the access token + ExpiresIn int `json:"expires_in"` + + // TokenType is the type of token, typically "Bearer" + TokenType string `json:"token_type"` +} + +// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider. +// It supports both authorization code and refresh token grant types. +// Parameters: +// - ctx: Context for the HTTP request +// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token") +// - codeOrToken: Either the authorization code or refresh token +// - redirectURL: The callback URL for authorization code grant func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, @@ -76,16 +112,8 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken return &tokenResponse, nil } -// TokenResponse represents the response from the token endpoint -type TokenResponse struct { - IDToken string `json:"id_token"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int `json:"expires_in"` - TokenType string `json:"token_type"` -} - -// getNewTokenWithRefreshToken refreshes the token using the refresh token +// getNewTokenWithRefreshToken obtains new tokens using a refresh token. +// This is used to refresh access tokens before they expire. func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { ctx := context.Background() tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "") @@ -94,24 +122,23 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe } t.logger.Debugf("Token response: %+v", tokenResponse) - return tokenResponse, nil } -// handleExpiredToken handles the case when a token has expired +// handleExpiredToken manages token expiration by clearing the session +// and initiating a new authentication flow. func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { - // Clear the existing session if err := session.Clear(req, rw); err != nil { t.logger.Errorf("Failed to clear session: %v", err) http.Error(rw, "Internal Server Error", http.StatusInternalServerError) return } - - // Initialize new authentication t.defaultInitiateAuthentication(rw, req, session, redirectURL) } -// handleCallback handles the callback from the OIDC provider +// handleCallback processes the authentication callback from the OIDC provider. +// It validates the callback parameters, exchanges the authorization code for +// tokens, verifies the tokens, and establishes the user's session. func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { session, err := t.sessionManager.GetSession(req) if err != nil { @@ -122,7 +149,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) - // Check for errors in the query parameters + // Check for errors in the callback if req.URL.Query().Get("error") != "" { errorDescription := req.URL.Query().Get("error_description") t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) @@ -130,7 +157,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Validate state parameter matches the session's CSRF token + // Validate CSRF state state := req.URL.Query().Get("state") if state == "" { t.logger.Error("No state in callback") @@ -166,7 +193,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Verify and process tokens + // Verify tokens and claims if err := t.verifyToken(tokenResponse.IDToken); err != nil { t.logger.Errorf("Failed to verify id_token: %v", err) http.Error(rw, "Authentication failed", http.StatusInternalServerError) @@ -180,7 +207,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Verify nonce + // Verify nonce to prevent replay attacks nonceClaim, ok := claims["nonce"].(string) if !ok || nonceClaim == "" { t.logger.Error("Nonce claim missing in id_token") @@ -201,7 +228,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Process email + // Validate user's email domain email, _ := claims["email"].(string) if email == "" || !t.isAllowedDomain(email) { t.logger.Errorf("Invalid or disallowed email: %s", email) @@ -209,13 +236,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - // Update session with new values + // Update session with authentication data session.SetAuthenticated(true) session.SetEmail(email) session.SetAccessToken(tokenResponse.IDToken) session.SetRefreshToken(tokenResponse.RefreshToken) - // Save session if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save session: %v", err) http.Error(rw, "Failed to save session", http.StatusInternalServerError) @@ -231,7 +257,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, http.Redirect(rw, req, redirectPath, http.StatusFound) } -// extractClaims extracts claims from a JWT token +// extractClaims parses a JWT token and extracts its claims. +// It handles base64url decoding and JSON parsing of the token payload. func extractClaims(tokenString string) (map[string]interface{}, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -251,27 +278,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) { return claims, nil } -// TokenBlacklist maintains a blacklist of tokens +// TokenBlacklist maintains a thread-safe list of revoked tokens. +// It stores tokens with their expiration times and automatically +// removes expired entries during cleanup operations. type TokenBlacklist struct { + // blacklist maps token IDs to their expiration times blacklist map[string]time.Time - mutex sync.RWMutex + + // mutex protects concurrent access to the blacklist + mutex sync.RWMutex } -// NewTokenBlacklist creates a new TokenBlacklist +// NewTokenBlacklist creates a new TokenBlacklist instance. func NewTokenBlacklist() *TokenBlacklist { return &TokenBlacklist{ blacklist: make(map[string]time.Time), } } -// Add adds a token to the blacklist +// Add adds a token to the blacklist with an expiration time. func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) { tb.mutex.Lock() defer tb.mutex.Unlock() tb.blacklist[tokenID] = expiration } -// IsBlacklisted checks if a token is blacklisted +// IsBlacklisted checks if a token is in the blacklist and not expired. func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { tb.mutex.RLock() defer tb.mutex.RUnlock() @@ -279,7 +311,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { return exists && time.Now().Before(expiration) } -// Cleanup removes expired tokens from the blacklist +// Cleanup removes expired tokens from the blacklist. func (tb *TokenBlacklist) Cleanup() { tb.mutex.Lock() defer tb.mutex.Unlock() @@ -291,25 +323,29 @@ func (tb *TokenBlacklist) Cleanup() { } } -// TokenCache caches tokens +// TokenCache provides a caching mechanism for validated tokens. +// It stores token claims to avoid repeated validation of the +// same token, improving performance for frequently used tokens. type TokenCache struct { + // cache is the underlying cache implementation cache *Cache } -// NewTokenCache creates a new TokenCache +// NewTokenCache creates a new TokenCache instance. func NewTokenCache() *TokenCache { return &TokenCache{ cache: NewCache(), } } -// Set sets a token in the cache +// Set stores a token's claims in the cache with an expiration time. func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) { token = "t-" + token tc.cache.Set(token, claims, expiration) } -// Get retrieves a token from the cache +// Get retrieves a token's claims from the cache. +// Returns the claims and a boolean indicating if the token was found. func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { token = "t-" + token value, found := tc.cache.Get(token) @@ -320,18 +356,18 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { return claims, ok } -// Delete removes a token from the cache +// Delete removes a token from the cache. func (tc *TokenCache) Delete(token string) { token = "t-" + token tc.cache.Delete(token) } -// Cleanup cleans up expired tokens from the cache +// Cleanup removes expired tokens from the cache. func (tc *TokenCache) Cleanup() { tc.cache.Cleanup() } -// exchangeCodeForToken exchanges the authorization code for tokens +// exchangeCodeForToken exchanges an authorization code for tokens. func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) { ctx := context.Background() tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL) @@ -341,7 +377,8 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*To return tokenResponse, nil } -// createStringMap creates a map from a slice of strings +// createStringMap creates a map from a slice of strings. +// Used for efficient lookups in allowed domains and roles. func createStringMap(keys []string) map[string]struct{} { result := make(map[string]struct{}) for _, key := range keys { @@ -350,7 +387,9 @@ func createStringMap(keys []string) map[string]struct{} { return result } -// handleLogout handles the logout request +// handleLogout manages the OIDC logout process. +// It clears the session and redirects either to the OIDC provider's +// end session endpoint (if available) or to the configured post-logout URL. func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { session, err := t.sessionManager.GetSession(req) if err != nil { @@ -359,22 +398,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { return } - // Get the access token before clearing session accessToken := session.GetAccessToken() - // Clear all session data if err := session.Clear(req, rw); err != nil { t.logger.Errorf("Error clearing session: %v", err) http.Error(rw, "Session error", http.StatusInternalServerError) return } - // Get the base URL for redirects host := t.determineHost(req) scheme := t.determineScheme(req) baseURL := fmt.Sprintf("%s://%s", scheme, host) - // Determine post logout redirect URI postLogoutRedirectURI := t.postLogoutRedirectURI if postLogoutRedirectURI == "" { postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL) @@ -382,7 +417,6 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI) } - // If we have an end session endpoint and an access token, use OIDC end session if t.endSessionURL != "" && accessToken != "" { logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI) if err != nil { @@ -394,11 +428,14 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { return } - // Otherwise, redirect to post logout URI http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound) } -// BuildLogoutURL constructs the OIDC end session URL +// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters. +// Parameters: +// - endSessionURL: The OIDC provider's end session endpoint +// - idToken: The ID token to be invalidated +// - postLogoutRedirectURI: Where to redirect after logout completes func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) { u, err := url.Parse(endSessionURL) if err != nil { @@ -408,7 +445,6 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin q := u.Query() q.Set("id_token_hint", idToken) if postLogoutRedirectURI != "" { - // Ensure postLogoutRedirectURI is properly URL encoded q.Set("post_logout_redirect_uri", postLogoutRedirectURI) } u.RawQuery = q.Encode() diff --git a/jwk.go b/jwk.go index b1d096c..57ad2b2 100644 --- a/jwk.go +++ b/jwk.go @@ -16,37 +16,74 @@ import ( "time" ) -// JWK represents a JSON Web Key +// JWK represents a JSON Web Key as defined in RFC 7517. +// It contains the cryptographic key information used for token verification. type JWK struct { + // Kty is the key type (e.g., "RSA", "EC") Kty string `json:"kty"` + + // Kid is the unique key identifier Kid string `json:"kid"` + + // Use specifies the intended use of the key (e.g., "sig" for signature) Use string `json:"use"` - N string `json:"n"` - E string `json:"e"` + + // N is the modulus for RSA keys + N string `json:"n"` + + // E is the exponent for RSA keys + E string `json:"e"` + + // Alg is the algorithm intended for use with the key Alg string `json:"alg"` + + // Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521") Crv string `json:"crv"` - X string `json:"x"` - Y string `json:"y"` + + // X is the x-coordinate for EC keys + X string `json:"x"` + + // Y is the y-coordinate for EC keys + Y string `json:"y"` } -// JWKSet represents a set of JWKs +// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint. +// OIDC providers typically expose multiple keys to support key rotation. type JWKSet struct { + // Keys is the array of JSON Web Keys Keys []JWK `json:"keys"` } -// JWKCache caches the JWKs +// JWKCache provides a thread-safe caching mechanism for JWK sets. +// It caches the keys for a configurable duration to reduce load on the OIDC provider +// while ensuring keys are refreshed periodically to handle key rotation. type JWKCache struct { - jwks *JWKSet + // jwks holds the cached set of JSON Web Keys + jwks *JWKSet + + // expiresAt is the timestamp when the cached keys should be refreshed expiresAt time.Time - mutex sync.RWMutex + + // mutex protects concurrent access to the cache + mutex sync.RWMutex } -// JWKCacheInterface defines the interface for the JWK cache +// JWKCacheInterface defines the interface for JWK caching operations. +// This interface allows for different caching implementations while +// maintaining consistent behavior in the token verification process. type JWKCacheInterface interface { GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) } -// GetJWKS gets the JWKS, either from cache or by fetching it +// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it +// from the OIDC provider. It implements a thread-safe double-checked locking +// pattern to prevent multiple simultaneous fetches of the same keys. +// Parameters: +// - jwksURL: The URL of the JWKS endpoint +// - httpClient: The HTTP client to use for fetching keys +// Returns: +// - The JSON Web Key Set +// - An error if the keys cannot be retrieved or parsed func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { c.mutex.RLock() if c.jwks != nil && time.Now().Before(c.expiresAt) { @@ -73,7 +110,14 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er return jwks, nil } -// fetchJWKS fetches the JWKS from the provider +// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint. +// It handles HTTP communication and JSON parsing of the response. +// Parameters: +// - jwksURL: The URL of the JWKS endpoint +// - httpClient: The HTTP client to use for the request +// Returns: +// - The parsed JSON Web Key Set +// - An error if the request fails or the response is invalid func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { resp, err := httpClient.Get(jwksURL) if err != nil { @@ -93,7 +137,9 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { return &jwks, nil } -// jwkToPEM converts a JWK to PEM format +// jwkToPEM converts a JSON Web Key to PEM format for use with standard +// cryptographic functions. It supports both RSA and EC keys, delegating +// to the appropriate converter based on the key type. func jwkToPEM(jwk *JWK) ([]byte, error) { converter, ok := jwkConverters[jwk.Kty] if !ok { @@ -109,7 +155,9 @@ var jwkConverters = map[string]jwkToPEMConverter{ "EC": ecJWKToPEM, } -// rsaJWKToPEM converts an RSA JWK to PEM +// rsaJWKToPEM converts an RSA JSON Web Key to PEM format. +// It handles base64url decoding of the modulus and exponent, +// constructs an RSA public key, and encodes it in PEM format. func rsaJWKToPEM(jwk *JWK) ([]byte, error) { nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) if err != nil { @@ -141,7 +189,10 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) { return pubKeyPEM, nil } -// ecJWKToPEM converts an EC JWK to PEM +// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format. +// It supports the P-256, P-384, and P-521 curves as defined in the +// OIDC specification, decoding the x and y coordinates and encoding +// the resulting public key in PEM format. func ecJWKToPEM(jwk *JWK) ([]byte, error) { xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X) if err != nil { diff --git a/jwt.go b/jwt.go index feae914..67e1c6d 100644 --- a/jwt.go +++ b/jwt.go @@ -16,15 +16,31 @@ import ( "time" ) -// JWT represents a JSON Web Token +// JWT represents a JSON Web Token as defined in RFC 7519. +// It contains the three parts of a JWT: header, claims (payload), +// and signature, along with the original token string. type JWT struct { - Header map[string]interface{} - Claims map[string]interface{} + // Header contains the token metadata (algorithm, key ID, etc.) + Header map[string]interface{} + + // Claims contains the token claims (subject, expiration, etc.) + Claims map[string]interface{} + + // Signature contains the raw signature bytes Signature []byte - Token string + + // Token is the original JWT string + Token string } -// parseJWT parses a JWT token string into a JWT struct +// parseJWT parses a JWT token string into a JWT struct. +// It validates the token format and decodes the three parts +// (header, claims, signature) using base64url decoding. +// Parameters: +// - tokenString: The raw JWT token string +// Returns: +// - A parsed JWT struct +// - An error if the token format is invalid or parsing fails func parseJWT(tokenString string) (*JWT, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -63,7 +79,14 @@ func parseJWT(tokenString string) (*JWT, error) { return jwt, nil } -// Verify verifies the standard claims in the JWT +// Verify validates the standard JWT claims as defined in RFC 7519. +// It checks: +// - issuer (iss) matches the expected issuer URL +// - audience (aud) includes the client ID +// - expiration time (exp) is in the future +// - issued at time (iat) is in the past +// - subject (sub) is present and not empty +// Returns an error if any validation fails. func (j *JWT) Verify(issuerURL, clientID string) error { claims := j.Claims @@ -107,7 +130,13 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } -// verifyAudience verifies the audience claim +// verifyAudience validates the token's audience claim. +// The audience can be either a single string or an array of strings. +// For array audiences, the expected audience must match any one value. +// Parameters: +// - tokenAudience: The audience claim from the token +// - expectedAudience: The expected audience value +// Returns an error if validation fails. func verifyAudience(tokenAudience interface{}, expectedAudience string) error { switch aud := tokenAudience.(type) { case string: @@ -131,7 +160,12 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error { return nil } -// verifyIssuer verifies the issuer claim +// verifyIssuer validates the token's issuer claim. +// The issuer URL must exactly match the expected issuer. +// Parameters: +// - tokenIssuer: The issuer claim from the token +// - expectedIssuer: The expected issuer URL +// Returns an error if validation fails. func verifyIssuer(tokenIssuer, expectedIssuer string) error { if tokenIssuer != expectedIssuer { return fmt.Errorf("invalid issuer") @@ -139,7 +173,11 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error { return nil } -// verifyExpiration checks if the token has expired +// verifyExpiration checks if the token's expiration time has passed. +// The expiration time is compared against the current time. +// Parameters: +// - expiration: The expiration timestamp from the token +// Returns an error if the token has expired. func verifyExpiration(expiration float64) error { expirationTime := time.Unix(int64(expiration), 0) if time.Now().After(expirationTime) { @@ -148,7 +186,12 @@ func verifyExpiration(expiration float64) error { return nil } -// verifyIssuedAt checks if the token was issued in the future +// verifyIssuedAt validates the token's issued-at time. +// Ensures the token wasn't issued in the future, which could +// indicate clock skew or a malicious token. +// Parameters: +// - issuedAt: The issued-at timestamp from the token +// Returns an error if the token was issued in the future. func verifyIssuedAt(issuedAt float64) error { issuedAtTime := time.Unix(int64(issuedAt), 0) if time.Now().Before(issuedAtTime) { @@ -157,7 +200,16 @@ func verifyIssuedAt(issuedAt float64) error { return nil } -// verifySignature verifies the token signature using the provided public key and algorithm +// verifySignature validates the token's cryptographic signature. +// Supports multiple signature algorithms: +// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5) +// - RSA-PSS: PS256, PS384, PS512 +// - ECDSA: ES256, ES384, ES512 +// Parameters: +// - tokenString: The complete JWT token string +// - publicKeyPEM: The PEM-encoded public key for verification +// - alg: The signature algorithm identifier +// Returns an error if signature verification fails. func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { // Split the token into its three parts parts := strings.Split(tokenString, ".") diff --git a/session.go b/session.go index e5bea79..d32ec12 100644 --- a/session.go +++ b/session.go @@ -8,34 +8,54 @@ import ( "github.com/gorilla/sessions" ) +// Cookie names and configuration constants used for session management const ( - mainCookieName = "_raczylo_oidc" // Main session cookie - accessTokenCookie = "_raczylo_oidc_access" // Access token cookie - refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie - maxCookieSize = 2000 // Max size for each chunk to stay within 4096-byte cookie limit + // mainCookieName is the name of the main session cookie that stores authentication state + // and basic user information like email and CSRF tokens + mainCookieName = "_raczylo_oidc" - // REASON: - // Let x be the maximum size of the chunk (maxCookieSize). - // Encrypted size = x + 28 bytes - // Base64-encoded size = ((x + 28) * 4) / 3 bytes - // ((x + 28) * 4) / 3 <= 4096 - // Multiply both sides by 3: - // 4 * (x + 28) <= 4096 * 3 - // 4 * (x + 28) <= 12288 - // Divide both sides by 4: - // x + 28 <= 3072 - // Subtract 28 from both sides: - // x <= 3044 + // accessTokenCookie is the name of the cookie that stores the OIDC access token + // This may be split into multiple cookies if the token is large + accessTokenCookie = "_raczylo_oidc_access" + + // refreshTokenCookie is the name of the cookie that stores the OIDC refresh token + // This may be split into multiple cookies if the token is large + refreshTokenCookie = "_raczylo_oidc_refresh" + + // maxCookieSize is the maximum size for each cookie chunk. + // This value is calculated to ensure the final cookie size stays within browser limits: + // 1. Browser cookie size limit is typically 4096 bytes + // 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio) + // 3. Calculation: + // - Let x be the chunk size + // - After encryption: x + 28 bytes + // - After base64: ((x + 28) * 4/3) bytes + // - Must satisfy: ((x + 28) * 4/3) ≤ 4096 + // - Solving for x: x ≤ 3044 + // 4. We use 2000 as a conservative limit to account for cookie metadata + maxCookieSize = 2000 ) -// SessionManager handles multiple session cookies +// SessionManager handles the management of multiple session cookies for OIDC authentication. +// It provides functionality for storing and retrieving authentication state, tokens, +// and other session-related data across multiple cookies to handle large tokens. type SessionManager struct { - store sessions.Store + // store is the underlying session store for cookie management + store sessions.Store + + // forceHTTPS enforces secure cookie attributes regardless of request scheme forceHTTPS bool - logger *Logger + + // logger provides structured logging capabilities + logger *Logger } -// NewSessionManager creates a new session manager +// NewSessionManager creates a new session manager with the specified configuration. +// Parameters: +// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes) +// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme +// - logger: Logger instance for recording session-related events +// The manager handles session creation, storage, and cookie security settings. func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager { return &SessionManager{ store: sessions.NewCookieStore([]byte(encryptionKey)), @@ -44,7 +64,14 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S } } -// getSessionOptions returns session options based on scheme +// getSessionOptions returns secure session options configured for the current request. +// Parameters: +// - isSecure: Whether the current request is using HTTPS +// The options ensure cookies are: +// - HTTP-only (not accessible via JavaScript) +// - Secure when using HTTPS or when forceHTTPS is enabled +// - Using SameSite=Lax for CSRF protection +// - Set with appropriate timeout and path settings func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { return &sessions.Options{ HttpOnly: true, @@ -55,7 +82,10 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options { } } -// GetSession retrieves all session data +// GetSession retrieves all session data for the current request. +// It loads the main session and token sessions, including any chunked token data, +// and combines them into a single SessionData structure for easy access. +// Returns an error if any session component cannot be loaded. func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { mainSession, err := sm.store.Get(r, mainCookieName) if err != nil { @@ -88,7 +118,12 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) { return sessionData, nil } -// getTokenChunkSessions retrieves sessions for token chunks +// getTokenChunkSessions retrieves all session chunks for a given token type. +// Parameters: +// - r: The HTTP request +// - baseName: The base name for the token's session cookies +// Returns a map of chunk index to session, used for handling large tokens +// that exceed single cookie size limits. func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session { chunks := make(map[int]*sessions.Session) for i := 0; ; i++ { @@ -103,18 +138,39 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string return chunks } -// SessionData holds all session information +// SessionData holds all session information for an authenticated user. +// It manages multiple session cookies to handle the main session state +// and potentially large access and refresh tokens that may need to be +// split across multiple cookies due to browser size limitations. type SessionData struct { - manager *SessionManager - request *http.Request - mainSession *sessions.Session - accessSession *sessions.Session - refreshSession *sessions.Session - accessTokenChunks map[int]*sessions.Session + // manager is the SessionManager that created this SessionData + manager *SessionManager + + // request is the current HTTP request associated with this session + request *http.Request + + // mainSession stores authentication state and basic user info + mainSession *sessions.Session + + // accessSession stores the primary access token cookie + accessSession *sessions.Session + + // refreshSession stores the primary refresh token cookie + refreshSession *sessions.Session + + // accessTokenChunks stores additional chunks of the access token + // when it exceeds the maximum cookie size + accessTokenChunks map[int]*sessions.Session + + // refreshTokenChunks stores additional chunks of the refresh token + // when it exceeds the maximum cookie size refreshTokenChunks map[int]*sessions.Session } -// Save saves all session data +// Save persists all session data to cookies in the HTTP response. +// It saves the main session, token sessions, and any token chunks, +// applying appropriate security options to each cookie. All cookies +// are saved with consistent security settings based on the request scheme. func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS @@ -158,7 +214,9 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error { return nil } -// Clear clears all session data +// Clear removes all session data by expiring all cookies and clearing their values. +// This is typically used during logout to ensure all session data is properly cleaned up. +// It handles both main session data and any token chunks that may exist. func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { // Clear and expire all sessions sd.mainSession.Options.MaxAge = -1 @@ -182,7 +240,9 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { return sd.Save(r, w) } -// clearTokenChunks clears chunked token sessions +// clearTokenChunks removes all session chunks for a given token type. +// It expires the cookies and removes all stored values to ensure +// no token data remains after logout or token invalidation. func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) { for _, session := range chunks { session.Options.MaxAge = -1 @@ -192,18 +252,24 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session } } -// GetAuthenticated returns authentication status +// GetAuthenticated returns whether the current session is authenticated. +// Returns true if the user has successfully completed OIDC authentication, +// false otherwise or if the authentication status cannot be determined. func (sd *SessionData) GetAuthenticated() bool { auth, _ := sd.mainSession.Values["authenticated"].(bool) return auth } -// SetAuthenticated sets authentication status +// SetAuthenticated updates the session's authentication status. +// This should be called after successful OIDC authentication or during logout. func (sd *SessionData) SetAuthenticated(value bool) { sd.mainSession.Values["authenticated"] = value } -// GetAccessToken returns the access token +// GetAccessToken retrieves the complete access token from the session. +// If the token was split into chunks due to size limitations, it will +// automatically reassemble the complete token from all chunks. +// Returns an empty string if no token is found. func (sd *SessionData) GetAccessToken() string { token, _ := sd.accessSession.Values["token"].(string) if token != "" { @@ -228,7 +294,11 @@ func (sd *SessionData) GetAccessToken() string { return strings.Join(chunks, "") } -// SetAccessToken sets the access token +// SetAccessToken stores the access token in the session. +// If the token exceeds maxCookieSize, it is automatically split into +// multiple cookie chunks to handle large tokens while staying within +// browser cookie size limits. Any existing token or chunks are cleared +// before setting the new token. func (sd *SessionData) SetAccessToken(token string) { // Clear existing chunks sd.clearTokenChunks(sd.request, sd.accessTokenChunks) @@ -249,7 +319,10 @@ func (sd *SessionData) SetAccessToken(token string) { } } -// GetRefreshToken returns the refresh token +// GetRefreshToken retrieves the complete refresh token from the session. +// If the token was split into chunks due to size limitations, it will +// automatically reassemble the complete token from all chunks. +// Returns an empty string if no token is found. func (sd *SessionData) GetRefreshToken() string { token, _ := sd.refreshSession.Values["token"].(string) if token != "" { @@ -274,7 +347,11 @@ func (sd *SessionData) GetRefreshToken() string { return strings.Join(chunks, "") } -// SetRefreshToken sets the refresh token +// SetRefreshToken stores the refresh token in the session. +// If the token exceeds maxCookieSize, it is automatically split into +// multiple cookie chunks to handle large tokens while staying within +// browser cookie size limits. Any existing token or chunks are cleared +// before setting the new token. func (sd *SessionData) SetRefreshToken(token string) { // Clear existing chunks sd.clearTokenChunks(sd.request, sd.refreshTokenChunks) @@ -295,7 +372,12 @@ func (sd *SessionData) SetRefreshToken(token string) { } } -// splitIntoChunks splits a string into chunks of specified size +// splitIntoChunks splits a string into chunks of specified size. +// This is used internally to handle large tokens that exceed cookie size limits. +// Parameters: +// - s: The string to split +// - chunkSize: Maximum size of each chunk +// Returns an array of string chunks, each no larger than chunkSize. func splitIntoChunks(s string, chunkSize int) []string { var chunks []string for len(s) > 0 { @@ -310,46 +392,65 @@ func splitIntoChunks(s string, chunkSize int) []string { return chunks } -// GetCSRF returns the CSRF token +// GetCSRF retrieves the CSRF token from the session. +// This token is used to prevent cross-site request forgery attacks +// by ensuring requests originate from the authenticated user. +// Returns an empty string if no CSRF token is found. func (sd *SessionData) GetCSRF() string { csrf, _ := sd.mainSession.Values["csrf"].(string) return csrf } -// SetCSRF sets the CSRF token +// SetCSRF stores a new CSRF token in the session. +// This should be called when initiating authentication to generate +// a new token for the authentication flow. func (sd *SessionData) SetCSRF(token string) { sd.mainSession.Values["csrf"] = token } -// GetNonce returns the nonce +// GetNonce retrieves the nonce value from the session. +// The nonce is used to prevent replay attacks in the OIDC flow +// by ensuring the token received matches the authentication request. +// Returns an empty string if no nonce is found. func (sd *SessionData) GetNonce() string { nonce, _ := sd.mainSession.Values["nonce"].(string) return nonce } -// SetNonce sets the nonce +// SetNonce stores a new nonce value in the session. +// This should be called when initiating authentication to generate +// a new nonce for the OIDC authentication flow. func (sd *SessionData) SetNonce(nonce string) { sd.mainSession.Values["nonce"] = nonce } -// GetEmail returns the user's email +// GetEmail retrieves the authenticated user's email address from the session. +// The email is typically extracted from the OIDC ID token claims. +// Returns an empty string if no email is found. func (sd *SessionData) GetEmail() string { email, _ := sd.mainSession.Values["email"].(string) return email } -// SetEmail sets the user's email +// SetEmail stores the user's email address in the session. +// This should be called after successful authentication when +// processing the OIDC ID token claims. func (sd *SessionData) SetEmail(email string) { sd.mainSession.Values["email"] = email } -// GetIncomingPath returns the original incoming path +// GetIncomingPath retrieves the original request path that triggered +// the authentication flow. This is used to redirect the user back +// to their intended destination after successful authentication. +// Returns an empty string if no path was stored. func (sd *SessionData) GetIncomingPath() string { path, _ := sd.mainSession.Values["incoming_path"].(string) return path } -// SetIncomingPath sets the original incoming path +// SetIncomingPath stores the original request path that triggered +// the authentication flow. This should be called before redirecting +// to the OIDC provider to remember where to send the user afterward. func (sd *SessionData) SetIncomingPath(path string) { sd.mainSession.Values["incoming_path"] = path } diff --git a/settings.go b/settings.go index e3d37de..95cbaaf 100644 --- a/settings.go +++ b/settings.go @@ -5,35 +5,93 @@ import ( "io" "log" "net/http" + "net/url" "os" + "strings" ) const ( cookieName = "_raczylo_oidc" ) -// Config holds the configuration for the OIDC middleware +// Config holds the configuration for the OIDC middleware. +// It provides all necessary settings to configure OpenID Connect authentication +// with various providers like Auth0, Logto, or any standard OIDC provider. type Config struct { - ProviderURL string `json:"providerURL"` - RevocationURL string `json:"revocationURL"` - CallbackURL string `json:"callbackURL"` - LogoutURL string `json:"logoutURL"` - ClientID string `json:"clientID"` - ClientSecret string `json:"clientSecret"` - Scopes []string `json:"scopes"` - LogLevel string `json:"logLevel"` - SessionEncryptionKey string `json:"sessionEncryptionKey"` - ForceHTTPS bool `json:"forceHTTPS"` - RateLimit int `json:"rateLimit"` - ExcludedURLs []string `json:"excludedURLs"` - AllowedUserDomains []string `json:"allowedUserDomains"` + // ProviderURL is the base URL of the OIDC provider (required) + // Example: https://accounts.google.com + ProviderURL string `json:"providerURL"` + + // RevocationURL is the endpoint for revoking tokens (optional) + // If not provided, it will be discovered from provider metadata + RevocationURL string `json:"revocationURL"` + + // CallbackURL is the path where the OIDC provider will redirect after authentication (required) + // Example: /oauth2/callback + CallbackURL string `json:"callbackURL"` + + // LogoutURL is the path for handling logout requests (optional) + // If not provided, it will be set to CallbackURL + "/logout" + LogoutURL string `json:"logoutURL"` + + // ClientID is the OAuth 2.0 client identifier (required) + ClientID string `json:"clientID"` + + // ClientSecret is the OAuth 2.0 client secret (required) + ClientSecret string `json:"clientSecret"` + + // Scopes defines the OAuth 2.0 scopes to request (optional) + // Defaults to ["openid", "profile", "email"] if not provided + Scopes []string `json:"scopes"` + + // LogLevel sets the logging verbosity (optional) + // Valid values: "debug", "info", "error" + // Default: "info" + LogLevel string `json:"logLevel"` + + // SessionEncryptionKey is used to encrypt session data (required) + // Must be a secure random string + SessionEncryptionKey string `json:"sessionEncryptionKey"` + + // ForceHTTPS forces the use of HTTPS for all URLs (optional) + // Default: false + ForceHTTPS bool `json:"forceHTTPS"` + + // RateLimit sets the maximum number of requests per second (optional) + // Default: 100 + RateLimit int `json:"rateLimit"` + + // ExcludedURLs lists paths that bypass authentication (optional) + // Example: ["/health", "/metrics"] + ExcludedURLs []string `json:"excludedURLs"` + + // AllowedUserDomains restricts access to specific email domains (optional) + // Example: ["company.com", "subsidiary.com"] + AllowedUserDomains []string `json:"allowedUserDomains"` + + // AllowedRolesAndGroups restricts access to users with specific roles or groups (optional) + // Example: ["admin", "developer"] AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"` - OIDCEndSessionURL string `json:"oidcEndSessionURL"` - PostLogoutRedirectURI string `json:"postLogoutRedirectURI"` - HTTPClient *http.Client + + // OIDCEndSessionURL is the provider's end session endpoint (optional) + // If not provided, it will be discovered from provider metadata + OIDCEndSessionURL string `json:"oidcEndSessionURL"` + + // PostLogoutRedirectURI is the URL to redirect to after logout (optional) + // Default: "/" + PostLogoutRedirectURI string `json:"postLogoutRedirectURI"` + + // HTTPClient allows customizing the HTTP client used for OIDC operations (optional) + HTTPClient *http.Client } -// CreateConfig creates a new Config with default values +// CreateConfig creates a new Config with sensible default values. +// Default values are set for optional fields: +// - Scopes: ["openid", "profile", "email"] +// - LogLevel: "info" +// - LogoutURL: CallbackURL + "/logout" +// - RateLimit: 100 requests per second +// - PostLogoutRedirectURI: "/" func CreateConfig() *Config { c := &Config{} @@ -56,14 +114,22 @@ func CreateConfig() *Config { return c } -// Validate validates the Config +// Validate performs validation checks on the Config. +// It ensures all required fields are set and have valid values. +// Returns an error if any validation check fails. func (c *Config) Validate() error { if c.ProviderURL == "" { return fmt.Errorf("providerURL is required") } + if !isValidURL(c.ProviderURL) { + return fmt.Errorf("providerURL must be a valid URL") + } if c.CallbackURL == "" { return fmt.Errorf("callbackURL is required") } + if !strings.HasPrefix(c.CallbackURL, "/") { + return fmt.Errorf("callbackURL must start with /") + } if c.ClientID == "" { return fmt.Errorf("clientID is required") } @@ -73,17 +139,48 @@ func (c *Config) Validate() error { if c.SessionEncryptionKey == "" { return fmt.Errorf("sessionEncryptionKey is required") } + if len(c.SessionEncryptionKey) < 32 { + return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long") + } + if c.RateLimit < 0 { + return fmt.Errorf("rateLimit must be non-negative") + } + if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) { + return fmt.Errorf("logLevel must be one of: debug, info, error") + } return nil } -// Logger is a simple logger with different levels +// isValidURL checks if the provided string is a valid URL +func isValidURL(s string) bool { + u, err := url.Parse(s) + return err == nil && u.Scheme != "" && u.Host != "" +} + +// isValidLogLevel checks if the provided log level is valid +func isValidLogLevel(level string) bool { + return level == "debug" || level == "info" || level == "error" +} + +// Logger provides structured logging capabilities with different severity levels. +// It supports error, info, and debug levels with appropriate output streams +// and formatting for each level. type Logger struct { + // logError handles error-level messages, writing to stderr logError *log.Logger - logInfo *log.Logger + // logInfo handles informational messages, writing to stdout + logInfo *log.Logger + // logDebug handles debug-level messages, writing to stdout when debug is enabled logDebug *log.Logger } -// NewLogger creates a new Logger +// NewLogger creates a new Logger with the specified log level. +// The log level determines which messages are output: +// - "debug": Outputs all messages (debug, info, error) +// - "info": Outputs info and error messages +// - "error": Outputs only error messages +// Error messages are always written to stderr, while info and debug +// messages are written to stdout when enabled. func NewLogger(logLevel string) *Logger { logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime) logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime) @@ -103,37 +200,51 @@ func NewLogger(logLevel string) *Logger { } } -// Info logs an info message +// Info logs an informational message. +// These messages are intended for general operational information +// and are written to stdout. func (l *Logger) Info(format string, args ...interface{}) { l.logInfo.Printf(format, args...) } -// Debug logs a debug message +// Debug logs a debug message. +// These messages are only output when debug level logging is enabled +// and are intended for detailed troubleshooting information. func (l *Logger) Debug(format string, args ...interface{}) { l.logDebug.Printf(format, args...) } -// Error logs an error message +// Error logs an error message. +// These messages indicate problems that need attention and are +// always written to stderr regardless of the log level. func (l *Logger) Error(format string, args ...interface{}) { l.logError.Printf(format, args...) } -// Infof logs an info message +// Infof logs an informational message using Printf formatting. +// These messages are intended for general operational information +// and are written to stdout. func (l *Logger) Infof(format string, args ...interface{}) { l.logInfo.Printf(format, args...) } -// Debugf logs a debug message +// Debugf logs a debug message using Printf formatting. +// These messages are only output when debug level logging is enabled +// and are intended for detailed troubleshooting information. func (l *Logger) Debugf(format string, args ...interface{}) { l.logDebug.Printf(format, args...) } -// Errorf logs an error message +// Errorf logs an error message using Printf formatting. +// These messages indicate problems that need attention and are +// always written to stderr regardless of the log level. func (l *Logger) Errorf(format string, args ...interface{}) { l.logError.Printf(format, args...) } -// handleError writes an error message to the response and logs it +// handleError writes an error message to both the HTTP response and the error log. +// It ensures consistent error handling across the middleware by logging the error +// and sending an appropriate HTTP response to the client. func handleError(w http.ResponseWriter, message string, code int, logger *Logger) { logger.Error(message) http.Error(w, message, code)