mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Improve documentation.
This commit is contained in:
@@ -5,26 +5,41 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
// Value is the cached data of any type
|
||||
Value interface{}
|
||||
|
||||
// ExpiresAt is the timestamp when this item should be considered expired
|
||||
// and removed from the cache during cleanup operations
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache is a simple in-memory cache
|
||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
|
||||
type Cache struct {
|
||||
// items stores the cached data with string keys
|
||||
items map[string]CacheItem
|
||||
|
||||
// mutex protects concurrent access to the items map
|
||||
// Use RLock/RUnlock for reads and Lock/Unlock for writes
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCache creates a new Cache
|
||||
// NewCache creates a new empty cache instance.
|
||||
// The cache is immediately ready for use and is thread-safe.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// Parameters:
|
||||
// - key: Unique identifier for the cached item
|
||||
// - value: The data to cache (can be of any type)
|
||||
// - expiration: How long the item should remain in the cache
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -34,7 +49,13 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Parameters:
|
||||
// - key: The identifier of the item to retrieve
|
||||
// Returns:
|
||||
// - value: The cached data (nil if not found or expired)
|
||||
// - found: true if the item was found and is valid, false otherwise
|
||||
// Thread-safe: Uses read locking to ensure safe concurrent access.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
@@ -49,14 +70,19 @@ func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
// Delete removes an item from the cache if it exists.
|
||||
// If the item doesn't exist, this operation is a no-op.
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// Cleanup removes expired items from the cache
|
||||
// Cleanup removes all expired items from the cache.
|
||||
// This should be called periodically to prevent memory leaks from
|
||||
// expired items that haven't been accessed (and thus not removed during Get operations).
|
||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
+85
-49
@@ -16,6 +16,13 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// newSessionOptions creates secure session cookie options.
|
||||
// Parameters:
|
||||
// - isSecure: Whether to set the Secure flag on cookies
|
||||
// Returns session options configured for security with:
|
||||
// - HttpOnly flag to prevent JavaScript access
|
||||
// - SameSite=Lax for CSRF protection
|
||||
// - Appropriate timeout and path settings
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
@@ -26,7 +33,10 @@ func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
}
|
||||
}
|
||||
|
||||
// generateNonce generates a random nonce
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
// the authentication request.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -36,7 +46,33 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// exchangeTokens exchanges a code or refresh token for tokens
|
||||
// TokenResponse represents the response from the OIDC token endpoint.
|
||||
// It contains the various tokens and metadata returned after successful
|
||||
// code exchange or token refresh operations.
|
||||
type TokenResponse struct {
|
||||
// IDToken is the OIDC ID token containing user claims
|
||||
IDToken string `json:"id_token"`
|
||||
|
||||
// AccessToken is the OAuth 2.0 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
|
||||
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
||||
// ExpiresIn is the lifetime in seconds of the access token
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// TokenType is the type of token, typically "Bearer"
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
|
||||
// It supports both authorization code and refresh token grant types.
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
@@ -76,16 +112,8 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the token endpoint
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken refreshes the token using the refresh token
|
||||
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
|
||||
// This is used to refresh access tokens before they expire.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
@@ -94,24 +122,23 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token response: %+v", tokenResponse)
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
// handleExpiredToken manages token expiration by clearing the session
|
||||
// and initiating a new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear the existing session
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to clear session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize new authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
// handleCallback processes the authentication callback from the OIDC provider.
|
||||
// It validates the callback parameters, exchanges the authorization code for
|
||||
// tokens, verifies the tokens, and establishes the user's session.
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -122,7 +149,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the query parameters
|
||||
// Check for errors in the callback
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
@@ -130,7 +157,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state parameter matches the session's CSRF token
|
||||
// Validate CSRF state
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
@@ -166,7 +193,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify and process tokens
|
||||
// Verify tokens and claims
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
@@ -180,7 +207,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce
|
||||
// Verify nonce to prevent replay attacks
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
@@ -201,7 +228,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Process email
|
||||
// Validate user's email domain
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
@@ -209,13 +236,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with new values
|
||||
// Update session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
@@ -231,7 +257,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims extracts claims from a JWT token
|
||||
// extractClaims parses a JWT token and extracts its claims.
|
||||
// It handles base64url decoding and JSON parsing of the token payload.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -251,27 +278,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// TokenBlacklist maintains a blacklist of tokens
|
||||
// TokenBlacklist maintains a thread-safe list of revoked tokens.
|
||||
// It stores tokens with their expiration times and automatically
|
||||
// removes expired entries during cleanup operations.
|
||||
type TokenBlacklist struct {
|
||||
// blacklist maps token IDs to their expiration times
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
|
||||
// mutex protects concurrent access to the blacklist
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist
|
||||
// NewTokenBlacklist creates a new TokenBlacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is blacklisted
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
@@ -279,7 +311,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
@@ -291,25 +323,29 @@ func (tb *TokenBlacklist) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCache caches tokens
|
||||
// TokenCache provides a caching mechanism for validated tokens.
|
||||
// It stores token claims to avoid repeated validation of the
|
||||
// same token, improving performance for frequently used tokens.
|
||||
type TokenCache struct {
|
||||
// cache is the underlying cache implementation
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache
|
||||
// NewTokenCache creates a new TokenCache instance.
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a token in the cache
|
||||
// Set stores a token's claims in the cache with an expiration time.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token from the cache
|
||||
// Get retrieves a token's claims from the cache.
|
||||
// Returns the claims and a boolean indicating if the token was found.
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
@@ -320,18 +356,18 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache
|
||||
// Delete removes a token from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup cleans up expired tokens from the cache
|
||||
// Cleanup removes expired tokens from the cache.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
@@ -341,7 +377,8 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*To
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings
|
||||
// createStringMap creates a map from a slice of strings.
|
||||
// Used for efficient lookups in allowed domains and roles.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
@@ -350,7 +387,9 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout handles the logout request
|
||||
// handleLogout manages the OIDC logout process.
|
||||
// It clears the session and redirects either to the OIDC provider's
|
||||
// end session endpoint (if available) or to the configured post-logout URL.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -359,22 +398,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get the access token before clearing session
|
||||
accessToken := session.GetAccessToken()
|
||||
|
||||
// Clear all session data
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the base URL for redirects
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
// Determine post logout redirect URI
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
@@ -382,7 +417,6 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
// If we have an end session endpoint and an access token, use OIDC end session
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
@@ -394,11 +428,14 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, redirect to post logout URI
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL
|
||||
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
|
||||
// Parameters:
|
||||
// - endSessionURL: The OIDC provider's end session endpoint
|
||||
// - idToken: The ID token to be invalidated
|
||||
// - postLogoutRedirectURI: Where to redirect after logout completes
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
@@ -408,7 +445,6 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
// Ensure postLogoutRedirectURI is properly URL encoded
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
@@ -16,37 +16,74 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It contains the cryptographic key information used for token verification.
|
||||
type JWK struct {
|
||||
// Kty is the key type (e.g., "RSA", "EC")
|
||||
Kty string `json:"kty"`
|
||||
|
||||
// Kid is the unique key identifier
|
||||
Kid string `json:"kid"`
|
||||
|
||||
// Use specifies the intended use of the key (e.g., "sig" for signature)
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
|
||||
// N is the modulus for RSA keys
|
||||
N string `json:"n"`
|
||||
|
||||
// E is the exponent for RSA keys
|
||||
E string `json:"e"`
|
||||
|
||||
// Alg is the algorithm intended for use with the key
|
||||
Alg string `json:"alg"`
|
||||
|
||||
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
|
||||
Crv string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
Y string `json:"y"`
|
||||
|
||||
// X is the x-coordinate for EC keys
|
||||
X string `json:"x"`
|
||||
|
||||
// Y is the y-coordinate for EC keys
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JWKs
|
||||
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
|
||||
// OIDC providers typically expose multiple keys to support key rotation.
|
||||
type JWKSet struct {
|
||||
// Keys is the array of JSON Web Keys
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache caches the JWKs
|
||||
// JWKCache provides a thread-safe caching mechanism for JWK sets.
|
||||
// It caches the keys for a configurable duration to reduce load on the OIDC provider
|
||||
// while ensuring keys are refreshed periodically to handle key rotation.
|
||||
type JWKCache struct {
|
||||
jwks *JWKSet
|
||||
// jwks holds the cached set of JSON Web Keys
|
||||
jwks *JWKSet
|
||||
|
||||
// expiresAt is the timestamp when the cached keys should be refreshed
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
|
||||
// mutex protects concurrent access to the cache
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for the JWK cache
|
||||
// JWKCacheInterface defines the interface for JWK caching operations.
|
||||
// This interface allows for different caching implementations while
|
||||
// maintaining consistent behavior in the token verification process.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
}
|
||||
|
||||
// GetJWKS gets the JWKS, either from cache or by fetching it
|
||||
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
|
||||
// from the OIDC provider. It implements a thread-safe double-checked locking
|
||||
// pattern to prevent multiple simultaneous fetches of the same keys.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for fetching keys
|
||||
// Returns:
|
||||
// - The JSON Web Key Set
|
||||
// - An error if the keys cannot be retrieved or parsed
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
@@ -73,7 +110,14 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// fetchJWKS fetches the JWKS from the provider
|
||||
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
|
||||
// It handles HTTP communication and JSON parsing of the response.
|
||||
// Parameters:
|
||||
// - jwksURL: The URL of the JWKS endpoint
|
||||
// - httpClient: The HTTP client to use for the request
|
||||
// Returns:
|
||||
// - The parsed JSON Web Key Set
|
||||
// - An error if the request fails or the response is invalid
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
if err != nil {
|
||||
@@ -93,7 +137,9 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK to PEM format
|
||||
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
|
||||
// cryptographic functions. It supports both RSA and EC keys, delegating
|
||||
// to the appropriate converter based on the key type.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -109,7 +155,9 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK to PEM
|
||||
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
|
||||
// It handles base64url decoding of the modulus and exponent,
|
||||
// constructs an RSA public key, and encodes it in PEM format.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -141,7 +189,10 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC JWK to PEM
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
|
||||
// It supports the P-256, P-384, and P-521 curves as defined in the
|
||||
// OIDC specification, decoding the x and y coordinates and encoding
|
||||
// the resulting public key in PEM format.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -16,15 +16,31 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
// It contains the three parts of a JWT: header, claims (payload),
|
||||
// and signature, along with the original token string.
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
// Header contains the token metadata (algorithm, key ID, etc.)
|
||||
Header map[string]interface{}
|
||||
|
||||
// Claims contains the token claims (subject, expiration, etc.)
|
||||
Claims map[string]interface{}
|
||||
|
||||
// Signature contains the raw signature bytes
|
||||
Signature []byte
|
||||
Token string
|
||||
|
||||
// Token is the original JWT string
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// It validates the token format and decodes the three parts
|
||||
// (header, claims, signature) using base64url decoding.
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT token string
|
||||
// Returns:
|
||||
// - A parsed JWT struct
|
||||
// - An error if the token format is invalid or parsing fails
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -63,7 +79,14 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify verifies the standard claims in the JWT
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// It checks:
|
||||
// - issuer (iss) matches the expected issuer URL
|
||||
// - audience (aud) includes the client ID
|
||||
// - expiration time (exp) is in the future
|
||||
// - issued at time (iat) is in the past
|
||||
// - subject (sub) is present and not empty
|
||||
// Returns an error if any validation fails.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
claims := j.Claims
|
||||
|
||||
@@ -107,7 +130,13 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience verifies the audience claim
|
||||
// verifyAudience validates the token's audience claim.
|
||||
// The audience can be either a single string or an array of strings.
|
||||
// For array audiences, the expected audience must match any one value.
|
||||
// Parameters:
|
||||
// - tokenAudience: The audience claim from the token
|
||||
// - expectedAudience: The expected audience value
|
||||
// Returns an error if validation fails.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
@@ -131,7 +160,12 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer verifies the issuer claim
|
||||
// verifyIssuer validates the token's issuer claim.
|
||||
// The issuer URL must exactly match the expected issuer.
|
||||
// Parameters:
|
||||
// - tokenIssuer: The issuer claim from the token
|
||||
// - expectedIssuer: The expected issuer URL
|
||||
// Returns an error if validation fails.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
@@ -139,7 +173,11 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token has expired
|
||||
// verifyExpiration checks if the token's expiration time has passed.
|
||||
// The expiration time is compared against the current time.
|
||||
// Parameters:
|
||||
// - expiration: The expiration timestamp from the token
|
||||
// Returns an error if the token has expired.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
if time.Now().After(expirationTime) {
|
||||
@@ -148,7 +186,12 @@ func verifyExpiration(expiration float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks if the token was issued in the future
|
||||
// verifyIssuedAt validates the token's issued-at time.
|
||||
// Ensures the token wasn't issued in the future, which could
|
||||
// indicate clock skew or a malicious token.
|
||||
// Parameters:
|
||||
// - issuedAt: The issued-at timestamp from the token
|
||||
// Returns an error if the token was issued in the future.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
@@ -157,7 +200,16 @@ func verifyIssuedAt(issuedAt float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifySignature verifies the token signature using the provided public key and algorithm
|
||||
// verifySignature validates the token's cryptographic signature.
|
||||
// Supports multiple signature algorithms:
|
||||
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
|
||||
// - RSA-PSS: PS256, PS384, PS512
|
||||
// - ECDSA: ES256, ES384, ES512
|
||||
// Parameters:
|
||||
// - tokenString: The complete JWT token string
|
||||
// - publicKeyPEM: The PEM-encoded public key for verification
|
||||
// - alg: The signature algorithm identifier
|
||||
// Returns an error if signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
|
||||
+149
-48
@@ -8,34 +8,54 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// Cookie names and configuration constants used for session management
|
||||
const (
|
||||
mainCookieName = "_raczylo_oidc" // Main session cookie
|
||||
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
|
||||
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
|
||||
maxCookieSize = 2000 // Max size for each chunk to stay within 4096-byte cookie limit
|
||||
// mainCookieName is the name of the main session cookie that stores authentication state
|
||||
// and basic user information like email and CSRF tokens
|
||||
mainCookieName = "_raczylo_oidc"
|
||||
|
||||
// REASON:
|
||||
// Let x be the maximum size of the chunk (maxCookieSize).
|
||||
// Encrypted size = x + 28 bytes
|
||||
// Base64-encoded size = ((x + 28) * 4) / 3 bytes
|
||||
// ((x + 28) * 4) / 3 <= 4096
|
||||
// Multiply both sides by 3:
|
||||
// 4 * (x + 28) <= 4096 * 3
|
||||
// 4 * (x + 28) <= 12288
|
||||
// Divide both sides by 4:
|
||||
// x + 28 <= 3072
|
||||
// Subtract 28 from both sides:
|
||||
// x <= 3044
|
||||
// accessTokenCookie is the name of the cookie that stores the OIDC access token
|
||||
// This may be split into multiple cookies if the token is large
|
||||
accessTokenCookie = "_raczylo_oidc_access"
|
||||
|
||||
// refreshTokenCookie is the name of the cookie that stores the OIDC refresh token
|
||||
// This may be split into multiple cookies if the token is large
|
||||
refreshTokenCookie = "_raczylo_oidc_refresh"
|
||||
|
||||
// maxCookieSize is the maximum size for each cookie chunk.
|
||||
// This value is calculated to ensure the final cookie size stays within browser limits:
|
||||
// 1. Browser cookie size limit is typically 4096 bytes
|
||||
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
|
||||
// 3. Calculation:
|
||||
// - Let x be the chunk size
|
||||
// - After encryption: x + 28 bytes
|
||||
// - After base64: ((x + 28) * 4/3) bytes
|
||||
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
|
||||
// - Solving for x: x ≤ 3044
|
||||
// 4. We use 2000 as a conservative limit to account for cookie metadata
|
||||
maxCookieSize = 2000
|
||||
)
|
||||
|
||||
// SessionManager handles multiple session cookies
|
||||
// SessionManager handles the management of multiple session cookies for OIDC authentication.
|
||||
// It provides functionality for storing and retrieving authentication state, tokens,
|
||||
// and other session-related data across multiple cookies to handle large tokens.
|
||||
type SessionManager struct {
|
||||
store sessions.Store
|
||||
// store is the underlying session store for cookie management
|
||||
store sessions.Store
|
||||
|
||||
// forceHTTPS enforces secure cookie attributes regardless of request scheme
|
||||
forceHTTPS bool
|
||||
logger *Logger
|
||||
|
||||
// logger provides structured logging capabilities
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager
|
||||
// NewSessionManager creates a new session manager with the specified configuration.
|
||||
// Parameters:
|
||||
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
|
||||
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
|
||||
// - logger: Logger instance for recording session-related events
|
||||
// The manager handles session creation, storage, and cookie security settings.
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
||||
return &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
@@ -44,7 +64,14 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
|
||||
}
|
||||
}
|
||||
|
||||
// getSessionOptions returns session options based on scheme
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
// - Using SameSite=Lax for CSRF protection
|
||||
// - Set with appropriate timeout and path settings
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
@@ -55,7 +82,10 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession retrieves all session data
|
||||
// GetSession retrieves all session data for the current request.
|
||||
// It loads the main session and token sessions, including any chunked token data,
|
||||
// and combines them into a single SessionData structure for easy access.
|
||||
// Returns an error if any session component cannot be loaded.
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
mainSession, err := sm.store.Get(r, mainCookieName)
|
||||
if err != nil {
|
||||
@@ -88,7 +118,12 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves sessions for token chunks
|
||||
// getTokenChunkSessions retrieves all session chunks for a given token type.
|
||||
// Parameters:
|
||||
// - r: The HTTP request
|
||||
// - baseName: The base name for the token's session cookies
|
||||
// Returns a map of chunk index to session, used for handling large tokens
|
||||
// that exceed single cookie size limits.
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
for i := 0; ; i++ {
|
||||
@@ -103,18 +138,39 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
|
||||
return chunks
|
||||
}
|
||||
|
||||
// SessionData holds all session information
|
||||
// SessionData holds all session information for an authenticated user.
|
||||
// It manages multiple session cookies to handle the main session state
|
||||
// and potentially large access and refresh tokens that may need to be
|
||||
// split across multiple cookies due to browser size limitations.
|
||||
type SessionData struct {
|
||||
manager *SessionManager
|
||||
request *http.Request
|
||||
mainSession *sessions.Session
|
||||
accessSession *sessions.Session
|
||||
refreshSession *sessions.Session
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
// manager is the SessionManager that created this SessionData
|
||||
manager *SessionManager
|
||||
|
||||
// request is the current HTTP request associated with this session
|
||||
request *http.Request
|
||||
|
||||
// mainSession stores authentication state and basic user info
|
||||
mainSession *sessions.Session
|
||||
|
||||
// accessSession stores the primary access token cookie
|
||||
accessSession *sessions.Session
|
||||
|
||||
// refreshSession stores the primary refresh token cookie
|
||||
refreshSession *sessions.Session
|
||||
|
||||
// accessTokenChunks stores additional chunks of the access token
|
||||
// when it exceeds the maximum cookie size
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
}
|
||||
|
||||
// Save saves all session data
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
// It saves the main session, token sessions, and any token chunks,
|
||||
// applying appropriate security options to each cookie. All cookies
|
||||
// are saved with consistent security settings based on the request scheme.
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
@@ -158,7 +214,9 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear clears all session data
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
// This is typically used during logout to ensure all session data is properly cleaned up.
|
||||
// It handles both main session data and any token chunks that may exist.
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
@@ -182,7 +240,9 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
return sd.Save(r, w)
|
||||
}
|
||||
|
||||
// clearTokenChunks clears chunked token sessions
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
// It expires the cookies and removes all stored values to ensure
|
||||
// no token data remains after logout or token invalidation.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
@@ -192,18 +252,24 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns authentication status
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
// Returns true if the user has successfully completed OIDC authentication,
|
||||
// false otherwise or if the authentication status cannot be determined.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
return auth
|
||||
}
|
||||
|
||||
// SetAuthenticated sets authentication status
|
||||
// SetAuthenticated updates the session's authentication status.
|
||||
// This should be called after successful OIDC authentication or during logout.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
}
|
||||
|
||||
// GetAccessToken returns the access token
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
// If the token was split into chunks due to size limitations, it will
|
||||
// automatically reassemble the complete token from all chunks.
|
||||
// Returns an empty string if no token is found.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -228,7 +294,11 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
return strings.Join(chunks, "")
|
||||
}
|
||||
|
||||
// SetAccessToken sets the access token
|
||||
// SetAccessToken stores the access token in the session.
|
||||
// If the token exceeds maxCookieSize, it is automatically split into
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
|
||||
@@ -249,7 +319,10 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken returns the refresh token
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
// If the token was split into chunks due to size limitations, it will
|
||||
// automatically reassemble the complete token from all chunks.
|
||||
// Returns an empty string if no token is found.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -274,7 +347,11 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
return strings.Join(chunks, "")
|
||||
}
|
||||
|
||||
// SetRefreshToken sets the refresh token
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
// If the token exceeds maxCookieSize, it is automatically split into
|
||||
// multiple cookie chunks to handle large tokens while staying within
|
||||
// browser cookie size limits. Any existing token or chunks are cleared
|
||||
// before setting the new token.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
|
||||
@@ -295,7 +372,12 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
// This is used internally to handle large tokens that exceed cookie size limits.
|
||||
// Parameters:
|
||||
// - s: The string to split
|
||||
// - chunkSize: Maximum size of each chunk
|
||||
// Returns an array of string chunks, each no larger than chunkSize.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
@@ -310,46 +392,65 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF returns the CSRF token
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
// This token is used to prevent cross-site request forgery attacks
|
||||
// by ensuring requests originate from the authenticated user.
|
||||
// Returns an empty string if no CSRF token is found.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRF token
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
// This should be called when initiating authentication to generate
|
||||
// a new token for the authentication flow.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce returns the nonce
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
// The nonce is used to prevent replay attacks in the OIDC flow
|
||||
// by ensuring the token received matches the authentication request.
|
||||
// Returns an empty string if no nonce is found.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce sets the nonce
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
// This should be called when initiating authentication to generate
|
||||
// a new nonce for the OIDC authentication flow.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail returns the user's email
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
// The email is typically extracted from the OIDC ID token claims.
|
||||
// Returns an empty string if no email is found.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail sets the user's email
|
||||
// SetEmail stores the user's email address in the session.
|
||||
// This should be called after successful authentication when
|
||||
// processing the OIDC ID token claims.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath returns the original incoming path
|
||||
// GetIncomingPath retrieves the original request path that triggered
|
||||
// the authentication flow. This is used to redirect the user back
|
||||
// to their intended destination after successful authentication.
|
||||
// Returns an empty string if no path was stored.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath sets the original incoming path
|
||||
// SetIncomingPath stores the original request path that triggered
|
||||
// the authentication flow. This should be called before redirecting
|
||||
// to the OIDC provider to remember where to send the user afterward.
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
+140
-29
@@ -5,35 +5,93 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "_raczylo_oidc"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC middleware
|
||||
// Config holds the configuration for the OIDC middleware.
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Scopes []string `json:"scopes"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
// ProviderURL is the base URL of the OIDC provider (required)
|
||||
// Example: https://accounts.google.com
|
||||
ProviderURL string `json:"providerURL"`
|
||||
|
||||
// RevocationURL is the endpoint for revoking tokens (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
|
||||
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
|
||||
// Example: /oauth2/callback
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
|
||||
// LogoutURL is the path for handling logout requests (optional)
|
||||
// If not provided, it will be set to CallbackURL + "/logout"
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
|
||||
// ClientID is the OAuth 2.0 client identifier (required)
|
||||
ClientID string `json:"clientID"`
|
||||
|
||||
// ClientSecret is the OAuth 2.0 client secret (required)
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
|
||||
// Scopes defines the OAuth 2.0 scopes to request (optional)
|
||||
// Defaults to ["openid", "profile", "email"] if not provided
|
||||
Scopes []string `json:"scopes"`
|
||||
|
||||
// LogLevel sets the logging verbosity (optional)
|
||||
// Valid values: "debug", "info", "error"
|
||||
// Default: "info"
|
||||
LogLevel string `json:"logLevel"`
|
||||
|
||||
// SessionEncryptionKey is used to encrypt session data (required)
|
||||
// Must be a secure random string
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
|
||||
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
|
||||
// Default: false
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
|
||||
// RateLimit sets the maximum number of requests per second (optional)
|
||||
// Default: 100
|
||||
RateLimit int `json:"rateLimit"`
|
||||
|
||||
// ExcludedURLs lists paths that bypass authentication (optional)
|
||||
// Example: ["/health", "/metrics"]
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
|
||||
// AllowedUserDomains restricts access to specific email domains (optional)
|
||||
// Example: ["company.com", "subsidiary.com"]
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
|
||||
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
|
||||
// Example: ["admin", "developer"]
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
HTTPClient *http.Client
|
||||
|
||||
// OIDCEndSessionURL is the provider's end session endpoint (optional)
|
||||
// If not provided, it will be discovered from provider metadata
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
|
||||
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
|
||||
// Default: "/"
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// CreateConfig creates a new Config with default values
|
||||
// CreateConfig creates a new Config with sensible default values.
|
||||
// Default values are set for optional fields:
|
||||
// - Scopes: ["openid", "profile", "email"]
|
||||
// - LogLevel: "info"
|
||||
// - LogoutURL: CallbackURL + "/logout"
|
||||
// - RateLimit: 100 requests per second
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{}
|
||||
|
||||
@@ -56,14 +114,22 @@ func CreateConfig() *Config {
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate validates the Config
|
||||
// Validate performs validation checks on the Config.
|
||||
// It ensures all required fields are set and have valid values.
|
||||
// Returns an error if any validation check fails.
|
||||
func (c *Config) Validate() error {
|
||||
if c.ProviderURL == "" {
|
||||
return fmt.Errorf("providerURL is required")
|
||||
}
|
||||
if !isValidURL(c.ProviderURL) {
|
||||
return fmt.Errorf("providerURL must be a valid URL")
|
||||
}
|
||||
if c.CallbackURL == "" {
|
||||
return fmt.Errorf("callbackURL is required")
|
||||
}
|
||||
if !strings.HasPrefix(c.CallbackURL, "/") {
|
||||
return fmt.Errorf("callbackURL must start with /")
|
||||
}
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("clientID is required")
|
||||
}
|
||||
@@ -73,17 +139,48 @@ func (c *Config) Validate() error {
|
||||
if c.SessionEncryptionKey == "" {
|
||||
return fmt.Errorf("sessionEncryptionKey is required")
|
||||
}
|
||||
if len(c.SessionEncryptionKey) < 32 {
|
||||
return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long")
|
||||
}
|
||||
if c.RateLimit < 0 {
|
||||
return fmt.Errorf("rateLimit must be non-negative")
|
||||
}
|
||||
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
|
||||
return fmt.Errorf("logLevel must be one of: debug, info, error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger is a simple logger with different levels
|
||||
// isValidURL checks if the provided string is a valid URL
|
||||
func isValidURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme != "" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level is valid
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
|
||||
// Logger provides structured logging capabilities with different severity levels.
|
||||
// It supports error, info, and debug levels with appropriate output streams
|
||||
// and formatting for each level.
|
||||
type Logger struct {
|
||||
// logError handles error-level messages, writing to stderr
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
// logInfo handles informational messages, writing to stdout
|
||||
logInfo *log.Logger
|
||||
// logDebug handles debug-level messages, writing to stdout when debug is enabled
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger
|
||||
// NewLogger creates a new Logger with the specified log level.
|
||||
// The log level determines which messages are output:
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
@@ -103,37 +200,51 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
// Info logs an informational message.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
// Debug logs a debug message.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
// Error logs an error message.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an info message
|
||||
// Infof logs an informational message using Printf formatting.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message
|
||||
// Debugf logs a debug message using Printf formatting.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message
|
||||
// Errorf logs an error message using Printf formatting.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to the response and logs it
|
||||
// handleError writes an error message to both the HTTP response and the error log.
|
||||
// It ensures consistent error handling across the middleware by logging the error
|
||||
// and sending an appropriate HTTP response to the client.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
Reference in New Issue
Block a user