mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a6e37f7fc | |||
| bd7eaf6dff | |||
| 3df19e6d90 | |||
| 1910cd6000 | |||
| 46c2f98a15 | |||
| 9e8634bfc0 | |||
| 23e019092a |
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Lukasz Raczylo
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -69,14 +69,15 @@ The middleware supports the following configuration options:
|
||||
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
|
||||
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
|
||||
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
|
||||
+10
-2
@@ -2,8 +2,16 @@ package traefikoidc
|
||||
|
||||
import "time"
|
||||
|
||||
// autoCleanupRoutine runs a ticker that calls the provided cleanup function at the specified interval.
|
||||
// It stops when a value is received on the stop channel.
|
||||
// autoCleanupRoutine periodically calls the provided cleanup function.
|
||||
// It starts a ticker with the given interval and executes the cleanup function
|
||||
// on each tick. The routine stops gracefully when a signal is received on the
|
||||
// stop channel. This is typically used for background cleanup tasks like
|
||||
// expiring cache entries.
|
||||
//
|
||||
// Parameters:
|
||||
// - interval: The time duration between cleanup calls.
|
||||
// - stop: A channel used to signal the routine to stop. Receiving any value will terminate the loop.
|
||||
// - cleanup: The function to call periodically for cleanup tasks.
|
||||
func autoCleanupRoutine(interval time.Duration, stop <-chan struct{}, cleanup func()) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
-110
@@ -1,110 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
|
||||
type TokenBlacklist struct {
|
||||
tokens map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new token blacklist instance.
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
tokens: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist with an expiration time.
|
||||
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
// Clean up expired tokens if we're at capacity
|
||||
if len(b.tokens) >= 1000 {
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
for t, exp := range b.tokens {
|
||||
if now.After(exp) || futureThreshold.After(exp) {
|
||||
delete(b.tokens, t)
|
||||
}
|
||||
}
|
||||
|
||||
// If still at capacity, remove oldest token
|
||||
if len(b.tokens) >= 1000 {
|
||||
var oldestToken string
|
||||
var oldestTime time.Time
|
||||
first := true
|
||||
for t, exp := range b.tokens {
|
||||
if first || exp.Before(oldestTime) {
|
||||
oldestToken = t
|
||||
oldestTime = exp
|
||||
first = false
|
||||
}
|
||||
}
|
||||
if oldestToken != "" {
|
||||
delete(b.tokens, oldestToken)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.tokens[token] = expiry
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is in the blacklist and not expired.
|
||||
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
|
||||
expiry, exists := b.tokens[token]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// If token is expired, remove it and return false
|
||||
if time.Now().After(expiry) {
|
||||
// Switch to write lock to remove expired token
|
||||
b.mutex.RUnlock()
|
||||
b.mutex.Lock()
|
||||
delete(b.tokens, token)
|
||||
b.mutex.Unlock()
|
||||
b.mutex.RLock()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist.
|
||||
// Also removes tokens that will expire within the next minute to prevent edge cases.
|
||||
func (b *TokenBlacklist) Cleanup() {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
futureThreshold := now.Add(time.Minute)
|
||||
|
||||
for token, expiry := range b.tokens {
|
||||
// Remove tokens that are expired or will expire soon
|
||||
if now.After(expiry) || futureThreshold.After(expiry) {
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes a token from the blacklist regardless of its expiration.
|
||||
func (b *TokenBlacklist) Remove(token string) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
delete(b.tokens, token)
|
||||
}
|
||||
|
||||
// Count returns the current number of tokens in the blacklist.
|
||||
func (b *TokenBlacklist) Count() int {
|
||||
b.mutex.RLock()
|
||||
defer b.mutex.RUnlock()
|
||||
return len(b.tokens)
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklist_Add(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
|
||||
if !blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be blacklisted, but it was not")
|
||||
}
|
||||
|
||||
if blacklist.IsBlacklisted("nonExistentToken") {
|
||||
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Cleanup(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(-time.Hour) // Expired token
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Cleanup()
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Remove(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token := "testToken"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token, expiry)
|
||||
blacklist.Remove(token)
|
||||
|
||||
if blacklist.IsBlacklisted(token) {
|
||||
t.Errorf("Expected token to be removed, but it was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist_Count(t *testing.T) {
|
||||
blacklist := NewTokenBlacklist()
|
||||
token1 := "token1"
|
||||
token2 := "token2"
|
||||
expiry := time.Now().Add(time.Hour)
|
||||
|
||||
blacklist.Add(token1, expiry)
|
||||
blacklist.Add(token2, expiry)
|
||||
|
||||
if blacklist.Count() != 2 {
|
||||
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
|
||||
}
|
||||
}
|
||||
@@ -46,7 +46,9 @@ type Cache struct {
|
||||
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||
const DefaultMaxSize = 500
|
||||
|
||||
// NewCache creates a new empty cache instance that is ready for use.
|
||||
// NewCache creates a new empty cache instance with default settings.
|
||||
// It initializes the internal maps and list, sets the default maximum size,
|
||||
// and starts the automatic cleanup goroutine.
|
||||
func NewCache() *Cache {
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
@@ -60,8 +62,12 @@ func NewCache() *Cache {
|
||||
return c
|
||||
}
|
||||
|
||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||
// It moves the item to the most recently used position.
|
||||
// Set adds or updates an item in the cache with the specified key, value, and expiration duration.
|
||||
// If the key already exists, its value and expiration time are updated, and it's moved
|
||||
// to the most recently used position in the LRU list.
|
||||
// If the key does not exist and the cache is full, the least recently used item is evicted
|
||||
// before adding the new item.
|
||||
// The expiration duration is relative to the time Set is called.
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -95,8 +101,11 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.elems[key] = elem
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||
// Moving the accessed item to the most recently used position.
|
||||
// Get retrieves an item from the cache by its key.
|
||||
// If the item exists and has not expired, its value and true are returned.
|
||||
// Accessing an item moves it to the most recently used position in the LRU list.
|
||||
// If the item does not exist or has expired, nil and false are returned, and the
|
||||
// expired item is removed from the cache.
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -120,7 +129,9 @@ func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache.
|
||||
// Delete removes an item from the cache by its key.
|
||||
// If the key exists, the corresponding item is removed from the cache storage
|
||||
// and the LRU list.
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -128,8 +139,10 @@ func (c *Cache) Delete(key string) {
|
||||
c.removeItem(key)
|
||||
}
|
||||
|
||||
// Cleanup removes all expired items from the cache. This should be called periodically
|
||||
// to prevent memory bloat from expired entries.
|
||||
// Cleanup iterates through the cache and removes all items that have expired.
|
||||
// An item is considered expired if the current time is after its ExpiresAt timestamp.
|
||||
// This method is called automatically by the auto-cleanup goroutine, but can also
|
||||
// be called manually.
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -143,7 +156,11 @@ func (c *Cache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used item from the cache.
|
||||
// evictOldest removes the least recently used (oldest) item from the cache.
|
||||
// It first attempts to find and remove an expired item from the front of the LRU list.
|
||||
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
|
||||
// This method is called internally by Set when the cache reaches its maximum size.
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
@@ -167,7 +184,9 @@ func (c *Cache) evictOldest() {
|
||||
}
|
||||
}
|
||||
|
||||
// removeItem removes an item from both the cache and the LRU tracking structures.
|
||||
// removeItem removes an item specified by the key from the cache's internal storage (items map)
|
||||
// and its corresponding entry from the LRU list (order list and elems map).
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) removeItem(key string) {
|
||||
delete(c.items, key)
|
||||
if elem, ok := c.elems[key]; ok {
|
||||
@@ -176,12 +195,15 @@ func (c *Cache) removeItem(key string) {
|
||||
}
|
||||
}
|
||||
|
||||
// startAutoCleanup initiates a goroutine that periodically cleans up expired cache items.
|
||||
// startAutoCleanup starts the background goroutine that automatically calls the Cleanup method
|
||||
// at the interval specified by c.autoCleanupInterval.
|
||||
// It uses the autoCleanupRoutine helper function.
|
||||
func (c *Cache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close terminates the auto cleanup goroutine.
|
||||
// Close stops the automatic cleanup goroutine associated with this cache instance.
|
||||
// It should be called when the cache is no longer needed to prevent resource leaks.
|
||||
func (c *Cache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+130
-182
@@ -15,10 +15,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce
|
||||
// for use in the OIDC authentication flow. The nonce is used to
|
||||
// prevent replay attacks by ensuring the token received matches
|
||||
// the authentication request.
|
||||
// generateNonce creates a cryptographically secure random string suitable for use as an OIDC nonce.
|
||||
// The nonce is used during the authentication flow to mitigate replay attacks by associating
|
||||
// the ID token with the specific authentication request.
|
||||
// It generates 32 random bytes and encodes them using base64 URL encoding.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (nonce).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -28,9 +32,13 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// generateCodeVerifier creates a cryptographically secure random string
|
||||
// for use as a PKCE code verifier. The code verifier must be between 43 and 128
|
||||
// characters long, per the PKCE spec (RFC 7636).
|
||||
// generateCodeVerifier creates a cryptographically secure random string suitable for use as a PKCE code verifier.
|
||||
// According to RFC 7636, the verifier should be a high-entropy string between 43 and 128 characters long.
|
||||
// This function generates 32 random bytes, resulting in a 43-character base64 URL encoded string.
|
||||
//
|
||||
// Returns:
|
||||
// - A base64 URL encoded random string (code verifier).
|
||||
// - An error if the random byte generation fails.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
// Using 32 bytes (256 bits) will produce a 43 character base64url string
|
||||
verifierBytes := make([]byte, 32)
|
||||
@@ -41,8 +49,15 @@ func generateCodeVerifier() (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(verifierBytes), nil
|
||||
}
|
||||
|
||||
// deriveCodeChallenge creates a code challenge from a code verifier
|
||||
// using the SHA-256 method as specified in the PKCE standard (RFC 7636).
|
||||
// deriveCodeChallenge computes the PKCE code challenge from a given code verifier.
|
||||
// It uses the S256 challenge method (SHA-256 hash followed by base64 URL encoding)
|
||||
// as defined in RFC 7636.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The high-entropy string generated by generateCodeVerifier.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 URL encoded SHA-256 hash of the code verifier (code challenge).
|
||||
func deriveCodeChallenge(codeVerifier string) string {
|
||||
// Calculate SHA-256 hash of the code verifier
|
||||
hasher := sha256.New()
|
||||
@@ -73,15 +88,23 @@ type TokenResponse struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
|
||||
// It supports both authorization code and refresh token grant types.
|
||||
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider's token endpoint.
|
||||
// It handles both the "authorization_code" grant type (exchanging an authorization code for tokens)
|
||||
// and the "refresh_token" grant type (using a refresh token to obtain new tokens).
|
||||
// It includes necessary parameters like client credentials and handles PKCE verification if applicable.
|
||||
// The function follows redirects and handles potential errors during the exchange.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
|
||||
// - codeOrToken: Either the authorization code or refresh token
|
||||
// - redirectURL: The callback URL for authorization code grant
|
||||
// - codeVerifier: Optional PKCE code verifier for authorization code grant
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// - ctx: The context for the outgoing HTTP request.
|
||||
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token").
|
||||
// - codeOrToken: The authorization code (for "authorization_code" grant) or the refresh token (for "refresh_token" grant).
|
||||
// - redirectURL: The redirect URI that was used in the initial authorization request (required for "authorization_code" grant).
|
||||
// - codeVerifier: The PKCE code verifier (required for "authorization_code" grant if PKCE was used).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens (ID, access, refresh).
|
||||
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant).
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
@@ -140,8 +163,16 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
|
||||
// This is used to refresh access tokens before they expire.
|
||||
// getNewTokenWithRefreshToken uses a refresh token to obtain a new set of tokens (ID, access, refresh)
|
||||
// from the OIDC provider's token endpoint. It wraps the exchangeTokens function with the
|
||||
// "refresh_token" grant type.
|
||||
//
|
||||
// Parameters:
|
||||
// - refreshToken: The refresh token previously obtained during authentication or a prior refresh.
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the newly obtained tokens.
|
||||
// - An error if the refresh operation fails.
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||
@@ -153,151 +184,17 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleExpiredToken manages token expiration by clearing the session
|
||||
// and initiating a new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear authentication data but preserve CSRF state
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
|
||||
// Save the cleared session state
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback processes the authentication callback from the OIDC provider.
|
||||
// It validates the callback parameters, exchanges the authorization code for
|
||||
// tokens, verifies the tokens, and establishes the user's session.
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the callback
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF state
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
http.Error(rw, "No code in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the code verifier from the session for PKCE flow
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify tokens and claims
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce to prevent replay attacks
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user's email domain
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to original path or root
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims parses a JWT token and extracts its claims.
|
||||
// It handles base64url decoding and JSON parsing of the token payload.
|
||||
// extractClaims decodes the payload (claims set) part of a JWT string.
|
||||
// It splits the JWT into its three parts, base64 URL decodes the second part (payload),
|
||||
// and unmarshals the resulting JSON into a map.
|
||||
// Note: This function does *not* validate the token's signature or claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A map representing the JSON claims extracted from the token payload.
|
||||
// - An error if the token format is invalid, decoding fails, or JSON unmarshaling fails.
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -325,21 +222,36 @@ type TokenCache struct {
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache instance.
|
||||
// NewTokenCache creates and initializes a new TokenCache.
|
||||
// It internally creates a new generic Cache instance for storage.
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a token's claims in the cache with an expiration time.
|
||||
// Set stores the claims associated with a specific token string in the cache.
|
||||
// It prefixes the token string to avoid potential collisions with other cache types
|
||||
// and sets the provided expiration duration.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string (used as the key).
|
||||
// - claims: The map of claims associated with the token.
|
||||
// - expiration: The duration for which the cache entry should be valid.
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token's claims from the cache.
|
||||
// Returns the claims and a boolean indicating if the token was found.
|
||||
// Get retrieves the cached claims for a given token string.
|
||||
// It prefixes the token string before querying the underlying cache.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to look up.
|
||||
//
|
||||
// Returns:
|
||||
// - The cached claims map if found and valid.
|
||||
// - A boolean indicating whether the token was found in the cache (true if found, false otherwise).
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
@@ -350,20 +262,34 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache.
|
||||
// Delete removes the cached entry for a specific token string.
|
||||
// It prefixes the token string before calling the underlying cache's Delete method.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The raw token string to remove from the cache.
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the cache.
|
||||
// Cleanup triggers the cleanup process for the underlying generic cache,
|
||||
// removing expired token entries.
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for tokens.
|
||||
// It handles PKCE (Proof Key for Code Exchange) based on middleware configuration.
|
||||
// The code verifier is only included in the token request if PKCE is enabled.
|
||||
// exchangeCodeForToken is a convenience function that wraps exchangeTokens specifically
|
||||
// for the "authorization_code" grant type. It handles the conditional inclusion of the
|
||||
// PKCE code verifier based on the middleware's configuration (t.enablePKCE).
|
||||
//
|
||||
// Parameters:
|
||||
// - code: The authorization code received from the OIDC provider.
|
||||
// - redirectURL: The redirect URI used in the initial authorization request.
|
||||
// - codeVerifier: The PKCE code verifier stored in the session (if PKCE is enabled).
|
||||
//
|
||||
// Returns:
|
||||
// - A TokenResponse containing the obtained tokens.
|
||||
// - An error if the code exchange fails.
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -380,8 +306,15 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string, code
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings.
|
||||
// Used for efficient lookups in allowed domains and roles.
|
||||
// createStringMap converts a slice of strings into a map[string]struct{} (a set).
|
||||
// This is useful for creating efficient lookups (O(1) average time complexity)
|
||||
// for checking the presence of items like allowed domains, roles, or groups.
|
||||
//
|
||||
// Parameters:
|
||||
// - keys: A slice of strings to be added to the set.
|
||||
//
|
||||
// Returns:
|
||||
// - A map where the keys are the strings from the input slice and the values are empty structs.
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
@@ -390,9 +323,17 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout manages the OIDC logout process.
|
||||
// It clears the session and redirects either to the OIDC provider's
|
||||
// end session endpoint (if available) or to the configured post-logout URL.
|
||||
// handleLogout processes requests to the configured logout path.
|
||||
// It performs the following steps:
|
||||
// 1. Retrieves the current user session.
|
||||
// 2. Gets the access token (ID token hint) from the session.
|
||||
// 3. Clears all authentication-related data from the session cookies.
|
||||
// 4. Determines the final post-logout redirect URI.
|
||||
// 5. If an OIDC end_session_endpoint is configured and an ID token hint is available,
|
||||
// it builds the OIDC logout URL and redirects the user agent to the provider for logout.
|
||||
// 6. Otherwise, it redirects the user agent directly to the post-logout redirect URI.
|
||||
//
|
||||
// It handles potential errors during session retrieval or clearing.
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
@@ -434,11 +375,18 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
|
||||
// BuildLogoutURL constructs the URL for redirecting the user agent to the OIDC provider's
|
||||
// end_session_endpoint, including the required id_token_hint and optional
|
||||
// post_logout_redirect_uri parameters as query arguments.
|
||||
//
|
||||
// Parameters:
|
||||
// - endSessionURL: The OIDC provider's end session endpoint
|
||||
// - idToken: The ID token to be invalidated
|
||||
// - postLogoutRedirectURI: Where to redirect after logout completes
|
||||
// - endSessionURL: The URL of the OIDC provider's end session endpoint.
|
||||
// - idToken: The ID token previously issued to the user (used as id_token_hint).
|
||||
// - postLogoutRedirectURI: The optional URI where the provider should redirect the user agent after logout.
|
||||
//
|
||||
// Returns:
|
||||
// - The fully constructed logout URL string.
|
||||
// - An error if the provided endSessionURL is invalid.
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
|
||||
+6
-166
@@ -7,172 +7,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBlacklistSizeLimit(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add tokens up to maxSize
|
||||
for i := 0; i < 1000; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Verify size is at max
|
||||
if tb.Count() != 1000 {
|
||||
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Add one more token, should trigger cleanup/eviction
|
||||
tb.Add("newtoken", time.Now().Add(time.Hour))
|
||||
|
||||
// Size should still be at max
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add some expired tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
|
||||
}
|
||||
|
||||
// Add some valid tokens
|
||||
for i := 0; i < 500; i++ {
|
||||
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
|
||||
}
|
||||
|
||||
// Force cleanup
|
||||
tb.Cleanup()
|
||||
|
||||
// Only valid tokens should remain
|
||||
if tb.Count() != 500 {
|
||||
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
|
||||
}
|
||||
|
||||
// Verify only valid tokens remain
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
for token, expiry := range tb.tokens {
|
||||
if time.Now().After(expiry) {
|
||||
t.Errorf("Found expired token after cleanup: %s", token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistOldestEviction(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
|
||||
// Add tokens at capacity with different expiration times
|
||||
baseTime := time.Now()
|
||||
oldestToken := "oldest"
|
||||
|
||||
// Add oldest token first
|
||||
tb.Add(oldestToken, baseTime.Add(time.Hour))
|
||||
|
||||
// Fill up to capacity with newer tokens
|
||||
for i := 0; i < 999; i++ {
|
||||
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
|
||||
}
|
||||
|
||||
// Add a new token that should evict the oldest
|
||||
newToken := "newest"
|
||||
tb.Add(newToken, baseTime.Add(time.Hour*3))
|
||||
|
||||
// Verify oldest token was evicted
|
||||
if tb.IsBlacklisted(oldestToken) {
|
||||
t.Error("Oldest token should have been evicted")
|
||||
}
|
||||
|
||||
// Verify newest token is present
|
||||
if !tb.IsBlacklisted(newToken) {
|
||||
t.Error("Newest token should be present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklistMemoryUsage(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
iterations := 10000
|
||||
|
||||
// Force initial GC
|
||||
runtime.GC()
|
||||
|
||||
// Record initial memory stats
|
||||
var m1, m2 runtime.MemStats
|
||||
runtime.ReadMemStats(&m1)
|
||||
|
||||
// Simulate heavy usage
|
||||
for i := 0; i < iterations; i++ {
|
||||
// Add new token
|
||||
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
|
||||
|
||||
// Periodically check blacklisted status
|
||||
if i%100 == 0 {
|
||||
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
|
||||
}
|
||||
|
||||
// Periodically cleanup
|
||||
if i%1000 == 0 {
|
||||
tb.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// Force GC and wait for it to complete
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
runtime.ReadMemStats(&m2)
|
||||
|
||||
// Check memory growth (using HeapAlloc for more accurate measurement)
|
||||
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
|
||||
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
|
||||
|
||||
if memoryGrowth > maxAllowedGrowth {
|
||||
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
|
||||
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
|
||||
}
|
||||
|
||||
// Verify size stayed within limits
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
|
||||
tb := NewTokenBlacklist()
|
||||
iterations := 1000
|
||||
concurrency := 10
|
||||
done := make(chan bool)
|
||||
|
||||
// Start multiple goroutines performing operations
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(id int) {
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Add tokens
|
||||
token := fmt.Sprintf("token%d-%d", id, j)
|
||||
tb.Add(token, time.Now().Add(time.Hour))
|
||||
|
||||
// Check blacklist status
|
||||
tb.IsBlacklisted(token)
|
||||
|
||||
// Periodic cleanup
|
||||
if j%100 == 0 {
|
||||
tb.Cleanup()
|
||||
}
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify size constraints were maintained
|
||||
if tb.Count() > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
|
||||
}
|
||||
}
|
||||
// Removed tests related to the old TokenBlacklist implementation:
|
||||
// - TestTokenBlacklistSizeLimit
|
||||
// - TestTokenBlacklistExpiredCleanup
|
||||
// - TestTokenBlacklistOldestEviction
|
||||
// - TestTokenBlacklistMemoryUsage
|
||||
// - TestConcurrentTokenBlacklistOperations
|
||||
|
||||
func TestTokenCacheMemoryUsage(t *testing.T) {
|
||||
tc := NewTokenCache()
|
||||
|
||||
@@ -45,6 +45,21 @@ type JWKCacheInterface interface {
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
// GetJWKS retrieves the JSON Web Key Set (JWKS) from the cache or fetches it from the provider.
|
||||
// It first checks if a valid, non-expired JWKS is present in the cache. If so, it returns the cached version.
|
||||
// Otherwise, it attempts to fetch the JWKS from the specified jwksURL using the provided httpClient.
|
||||
// If the fetch is successful, the JWKS is stored in the cache with an expiration time based on CacheLifetime
|
||||
// (defaulting to 1 hour if not set) and returned.
|
||||
// This method uses double-checked locking to minimize contention when the cache needs refreshing.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request if fetching is required.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for fetching the JWKS.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the JWKSet containing the keys.
|
||||
// - An error if fetching fails or the response cannot be decoded.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
@@ -74,6 +89,8 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// Cleanup removes the cached JWKS if it has expired.
|
||||
// This is intended to be called periodically to ensure stale JWKS data is cleared.
|
||||
func (c *JWKCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -84,6 +101,17 @@ func (c *JWKCache) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// fetchJWKS retrieves the JSON Web Key Set (JWKS) from the specified URL.
|
||||
// It uses the provided context and HTTP client to make the request.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the HTTP request.
|
||||
// - jwksURL: The URL of the OIDC provider's JWKS endpoint.
|
||||
// - httpClient: The HTTP client to use for the request.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the fetched JWKSet.
|
||||
// - An error if the request fails, the status code is not OK, or the response body cannot be decoded.
|
||||
func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Create a request with context to enforce timeout
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
|
||||
@@ -109,6 +137,16 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK (JSON Web Key) object into PEM (Privacy-Enhanced Mail) format.
|
||||
// It selects the appropriate conversion function based on the JWK's key type ("kty").
|
||||
// Currently supports "RSA" and "EC" key types.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the JWK object to convert.
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the public key in PEM format.
|
||||
// - An error if the key type is unsupported or conversion fails.
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
@@ -124,6 +162,17 @@ var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK into PEM format.
|
||||
// It decodes the modulus (n) and exponent (e) from base64 URL encoding,
|
||||
// constructs an rsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the RSA JWK object (must have "kty": "RSA").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the RSA public key in PEM format.
|
||||
// - An error if decoding parameters fails or key marshaling fails.
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
@@ -155,6 +204,18 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return pubKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC (Elliptic Curve) JWK into PEM format.
|
||||
// It decodes the X and Y coordinates from base64 URL encoding, determines the
|
||||
// elliptic curve based on the "crv" parameter (P-256, P-384, P-521),
|
||||
// constructs an ecdsa.PublicKey, marshals it into PKIX format, and then
|
||||
// encodes it as a PEM block.
|
||||
//
|
||||
// Parameters:
|
||||
// - jwk: A pointer to the EC JWK object (must have "kty": "EC").
|
||||
//
|
||||
// Returns:
|
||||
// - A byte slice containing the EC public key in PEM format.
|
||||
// - An error if decoding parameters fails, the curve is unsupported, or key marshaling fails.
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
|
||||
@@ -15,9 +15,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var replayCacheMu sync.Mutex
|
||||
var replayCache = make(map[string]time.Time)
|
||||
var (
|
||||
replayCacheMu sync.Mutex
|
||||
replayCache = make(map[string]time.Time)
|
||||
)
|
||||
|
||||
// cleanupReplayCache iterates through the replay cache and removes entries
|
||||
// whose expiration time is before the current time. This function should be
|
||||
// called periodically to prevent the cache from growing indefinitely.
|
||||
// It acquires a mutex to ensure thread safety during cleanup.
|
||||
func cleanupReplayCache() {
|
||||
now := time.Now()
|
||||
for token, expiry := range replayCache {
|
||||
@@ -27,8 +33,16 @@ func cleanupReplayCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// ClockSkewTolerance is configurable to adjust time-based validations.
|
||||
var ClockSkewTolerance = 2 * time.Minute
|
||||
// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'.
|
||||
// Allows for more leniency with expiration checks.
|
||||
var ClockSkewToleranceFuture = 2 * time.Minute
|
||||
|
||||
// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'.
|
||||
// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future.
|
||||
var (
|
||||
ClockSkewTolerancePast = 10 * time.Second
|
||||
ClockSkewTolerance = 2 * time.Minute
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
@@ -38,7 +52,18 @@ type JWT struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct.
|
||||
// parseJWT decodes a raw JWT string into its constituent parts: header, claims, and signature.
|
||||
// It splits the token string by '.', decodes each part using base64 URL decoding,
|
||||
// and unmarshals the header and claims JSON into maps. The raw signature bytes are stored.
|
||||
// It performs basic format validation (expecting 3 parts).
|
||||
// Note: This function does *not* validate the signature or the claims.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw JWT string.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a JWT struct containing the decoded parts.
|
||||
// - An error if the token format is invalid or decoding/unmarshaling fails.
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -74,8 +99,24 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify validates the standard JWT claims as defined in RFC 7519.
|
||||
// Verify performs standard claim validation on the JWT according to RFC 7519.
|
||||
// It checks the following:
|
||||
// - Algorithm ('alg') is supported.
|
||||
// - Issuer ('iss') matches the expected issuerURL.
|
||||
// - Audience ('aud') contains the expected clientID.
|
||||
// - Expiration time ('exp') is in the future (within tolerance).
|
||||
// - Issued at time ('iat') is in the past (within tolerance).
|
||||
// - Not before time ('nbf'), if present, is in the past (within tolerance).
|
||||
// - Subject ('sub') claim exists and is not empty.
|
||||
// - JWT ID ('jti'), if present, is checked against a replay cache to prevent token reuse.
|
||||
//
|
||||
// Parameters:
|
||||
// - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com").
|
||||
// - clientID: The expected audience value (the client ID of this application).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if all standard claims are valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
// Validate algorithm to prevent algorithm switching attacks
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
@@ -164,6 +205,17 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the expected audience is present in the token's 'aud' claim.
|
||||
// The 'aud' claim can be a single string or an array of strings.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenAudience: The 'aud' claim value extracted from the token (can be string or []interface{}).
|
||||
// - expectedAudience: The audience value expected for this application (client ID).
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the expected audience is found.
|
||||
// - An error if the claim type is invalid or the expected audience is not present.
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
@@ -187,6 +239,15 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer checks if the token's 'iss' claim matches the expected issuer URL.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenIssuer: The 'iss' claim value from the token.
|
||||
// - expectedIssuer: The expected issuer URL configured for the OIDC provider.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the issuers match.
|
||||
// - An error if the issuers do not match.
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer (token: %s, expected: %s)", tokenIssuer, expectedIssuer)
|
||||
@@ -194,54 +255,76 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyTimeConstraint is a generic function to verify time-based claims
|
||||
// verifyTimeConstraint checks time-based claims ('exp', 'iat', 'nbf') against the current time,
|
||||
// allowing for configurable clock skew. It uses different tolerances for past and future checks.
|
||||
//
|
||||
// Parameters:
|
||||
// - unixTime: The timestamp value from the claim (as a float64 Unix time).
|
||||
// - claimName: The name of the claim being verified ("exp", "iat", "nbf").
|
||||
// - future: A boolean indicating the direction of the check (true for 'exp', false for 'iat'/'nbf').
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the time constraint is met within the allowed tolerance.
|
||||
// - An error describing the failure (e.g., "token has expired", "token used before issued").
|
||||
func verifyTimeConstraint(unixTime float64, claimName string, future bool) error {
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
now := time.Now() // Use current time without truncation
|
||||
|
||||
// For expiration (future=true), we add skew to now (making now later)
|
||||
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
|
||||
skewDirection := 1
|
||||
if !future {
|
||||
skewDirection = -1
|
||||
}
|
||||
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For expiration: if skewedNow (later) is after expiration, token expired
|
||||
// For iat/nbf: if skewedNow (earlier) is before claim time, token not yet valid
|
||||
if (future && skewedNow.After(claimTime)) || (!future && skewedNow.Before(claimTime)) {
|
||||
var reason string
|
||||
if future {
|
||||
reason = "has expired"
|
||||
} else {
|
||||
var err error
|
||||
if future { // 'exp' check
|
||||
// Token is expired if Now is after (ClaimTime + FutureTolerance)
|
||||
allowedExpiry := claimTime.Add(ClockSkewToleranceFuture)
|
||||
if now.After(allowedExpiry) {
|
||||
err = fmt.Errorf("token has expired (exp: %v, now: %v, allowed_until: %v)", claimTime.UTC(), now.UTC(), allowedExpiry.UTC())
|
||||
}
|
||||
} else { // 'iat' or 'nbf' check
|
||||
// Token is invalid if Now is before (ClaimTime - PastTolerance)
|
||||
allowedStart := claimTime.Add(-ClockSkewTolerancePast)
|
||||
if now.Before(allowedStart) {
|
||||
reason := "not yet valid"
|
||||
if claimName == "iat" {
|
||||
reason = "used before issued"
|
||||
} else {
|
||||
reason = "not yet valid"
|
||||
}
|
||||
err = fmt.Errorf("token %s (%s: %v, now: %v, allowed_from: %v)", reason, claimName, claimTime.UTC(), now.UTC(), allowedStart.UTC())
|
||||
}
|
||||
return fmt.Errorf("token %s (%s: %v, now: %v)", reason, claimName, claimTime.UTC(), now.UTC())
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// verifyExpiration checks the 'exp' (Expiration Time) claim.
|
||||
// It calls verifyTimeConstraint with future=true.
|
||||
func verifyExpiration(expiration float64) error {
|
||||
return verifyTimeConstraint(expiration, "exp", true)
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks the 'iat' (Issued At) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
return verifyTimeConstraint(issuedAt, "iat", false)
|
||||
}
|
||||
|
||||
// verifyNotBefore checks the 'nbf' (Not Before) claim.
|
||||
// It calls verifyTimeConstraint with future=false.
|
||||
func verifyNotBefore(notBefore float64) error {
|
||||
return verifyTimeConstraint(notBefore, "nbf", false)
|
||||
}
|
||||
|
||||
// verifySignature validates the JWT's signature using the provided public key.
|
||||
// It parses the public key from PEM format, selects the appropriate hashing algorithm
|
||||
// based on the 'alg' parameter (SHA256/384/512), hashes the token's signing input
|
||||
// (header + "." + payload), and then verifies the signature against the hash using
|
||||
// the corresponding RSA (PKCS1v15 or PSS) or ECDSA verification method.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenString: The raw, complete JWT string.
|
||||
// - publicKeyPEM: The public key corresponding to the private key used for signing, in PEM format.
|
||||
// - alg: The algorithm specified in the JWT header (e.g., "RS256", "ES384").
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the signature is valid.
|
||||
// - An error if the token format is invalid, decoding fails, key parsing fails,
|
||||
// the algorithm is unsupported, or the signature verification fails.
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
|
||||
+657
-70
@@ -101,29 +101,47 @@ func (ts *TestSuite) Setup() {
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
revocationURL: "https://revocation-endpoint.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
||||
tokenCache: NewTokenCache(),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
// Explicitly set paths as New() is bypassed
|
||||
redirURLPath: "/callback", // Assume default callback path for tests
|
||||
logoutURLPath: "/callback/logout", // Assume default logout path for tests
|
||||
tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
}
|
||||
close(ts.tOidc.initComplete)
|
||||
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
|
||||
// ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc // Removed
|
||||
ts.tOidc.tokenVerifier = ts.tOidc
|
||||
ts.tOidc.jwtVerifier = ts.tOidc
|
||||
// Set default mock exchanger
|
||||
ts.tOidc.tokenExchanger = &MockTokenExchanger{
|
||||
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
// Default mock behavior for code exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token, // Use the valid token from setup
|
||||
AccessToken: ts.token,
|
||||
RefreshToken: "default-refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
},
|
||||
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
||||
// Default mock behavior for refresh (can be overridden in tests)
|
||||
return nil, fmt.Errorf("default mock: refresh not expected")
|
||||
},
|
||||
RevokeTokenFunc: func(token, tokenType string) error {
|
||||
// Default mock behavior for revoke
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions used by TraefikOidc
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
}
|
||||
// Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface.
|
||||
|
||||
// MockJWKCache implements JWKCacheInterface
|
||||
type MockJWKCache struct {
|
||||
@@ -141,6 +159,34 @@ func (m *MockJWKCache) Cleanup() {
|
||||
m.Err = nil
|
||||
}
|
||||
|
||||
// MockTokenExchanger implements TokenExchanger for testing
|
||||
type MockTokenExchanger struct {
|
||||
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
}
|
||||
|
||||
func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
if m.ExchangeCodeFunc != nil {
|
||||
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
||||
}
|
||||
return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock")
|
||||
}
|
||||
|
||||
func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
if m.RefreshTokenFunc != nil {
|
||||
return m.RefreshTokenFunc(refreshToken)
|
||||
}
|
||||
return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock")
|
||||
}
|
||||
|
||||
func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
if m.RevokeTokenFunc != nil {
|
||||
return m.RevokeTokenFunc(token, tokenType)
|
||||
}
|
||||
return fmt.Errorf("RevokeTokenFunc not implemented in mock")
|
||||
}
|
||||
|
||||
// Helper function to create a JWT token
|
||||
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
||||
header := map[string]interface{}{
|
||||
@@ -228,13 +274,14 @@ func TestVerifyToken(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset token blacklist and cache for each test
|
||||
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
|
||||
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
||||
ts.tOidc.tokenCache = NewTokenCache()
|
||||
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
||||
|
||||
// Set up the test case
|
||||
if tc.blacklist {
|
||||
ts.tOidc.tokenBlacklist.Add(tc.token, time.Now().Add(1*time.Hour))
|
||||
// Use Set with a duration. Value 'true' is arbitrary.
|
||||
ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour)
|
||||
}
|
||||
|
||||
if tc.rateLimit {
|
||||
@@ -282,13 +329,54 @@ func TestServeHTTP(t *testing.T) {
|
||||
ts.tOidc.next = nextHandler
|
||||
ts.tOidc.name = "test"
|
||||
|
||||
// Helper to create an expired token
|
||||
createExpiredToken := func() string {
|
||||
exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago
|
||||
iat := time.Now().Add(-2 * time.Hour).Unix()
|
||||
nbf := time.Now().Add(-2 * time.Hour).Unix()
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"nonce": "test-nonce-expired", // Different nonce for clarity
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
return expiredToken
|
||||
}
|
||||
|
||||
// Helper to create a new valid token (simulating refresh)
|
||||
createNewValidToken := func() string {
|
||||
exp := time.Now().Add(1 * time.Hour).Unix() // Valid for 1 hour
|
||||
iat := time.Now().Unix()
|
||||
nbf := time.Now().Unix()
|
||||
newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
// "nonce": "test-nonce-new", // Nonce is typically not included/validated in refreshed tokens
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
return newToken
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestPath string
|
||||
sessionValues map[interface{}]interface{}
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
setupSession func(*SessionData)
|
||||
name string
|
||||
requestPath string
|
||||
sessionValues map[interface{}]interface{}
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
setupSession func(*SessionData)
|
||||
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
|
||||
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks
|
||||
requestHeaders map[string]string // Added for setting headers like Accept
|
||||
}{
|
||||
{
|
||||
name: "Excluded URL",
|
||||
@@ -297,70 +385,379 @@ func TestServeHTTP(t *testing.T) {
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated request to protected URL",
|
||||
requestPath: "/protected",
|
||||
expectedStatus: http.StatusFound,
|
||||
name: "Unauthenticated request (no refresh token) to protected URL",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// Ensure no tokens are set
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
|
||||
},
|
||||
{
|
||||
name: "Authenticated request to protected URL",
|
||||
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(false) // Not authenticated
|
||||
session.SetAccessToken("") // No access token
|
||||
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
if refreshToken != "valid-refresh-token-for-unauth-test" {
|
||||
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
||||
}
|
||||
// Simulate successful refresh
|
||||
newToken := createNewValidToken() // Use helper from TestServeHTTP
|
||||
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect OK after successful refresh
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(false) // Not authenticated
|
||||
session.SetAccessToken("") // No access token
|
||||
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
// Simulate failed refresh
|
||||
return nil, fmt.Errorf("mock error: refresh token invalid")
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
|
||||
},
|
||||
{
|
||||
name: "Authenticated request to protected URL (Valid Token)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ts.token)
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "OK",
|
||||
},
|
||||
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
|
||||
{
|
||||
name: "Authenticated request with expired token and successful refresh",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken()) // Set expired token
|
||||
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
if refreshToken != "valid-refresh-token" {
|
||||
return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken)
|
||||
}
|
||||
// Simulate successful refresh
|
||||
newToken := createNewValidToken()
|
||||
return &TokenResponse{
|
||||
IDToken: newToken, // Return new valid token
|
||||
AccessToken: newToken, // Often the same as ID token in tests
|
||||
RefreshToken: "new-refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect success after refresh
|
||||
expectedBody: "OK",
|
||||
assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) {
|
||||
// Create a new request to read the cookies set by the response recorder
|
||||
reqForCookieRead := httptest.NewRequest("GET", "/protected", nil)
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
reqForCookieRead.AddCookie(cookie)
|
||||
}
|
||||
// Get session based on response cookies
|
||||
session, err := sessionManager.GetSession(reqForCookieRead)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session after request: %v", err)
|
||||
}
|
||||
// Assert new tokens are in the session
|
||||
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
|
||||
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
|
||||
}
|
||||
if session.GetRefreshToken() != "new-refresh-token" {
|
||||
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
|
||||
}
|
||||
// Also check authenticated flag is now true
|
||||
if !session.GetAuthenticated() {
|
||||
t.Errorf("Expected session to be marked authenticated after successful refresh")
|
||||
}
|
||||
},
|
||||
},
|
||||
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
|
||||
{
|
||||
name: "Logout URL",
|
||||
requestPath: "/logout",
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ts.token)
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedStatus: http.StatusFound, // Expect redirect after logout
|
||||
expectedBody: "",
|
||||
// No specific session assertion needed for logout redirect itself
|
||||
},
|
||||
{
|
||||
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken()) // Expired access token
|
||||
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
// Simulate failed refresh
|
||||
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
||||
}
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
|
||||
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
|
||||
},
|
||||
// This test case remains valid as the logic should still redirect browser clients on refresh failure
|
||||
{
|
||||
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken()) // Expired access token
|
||||
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
// Simulate failed refresh
|
||||
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
||||
}
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "text/html", // Browser client
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
|
||||
},
|
||||
// This test case remains valid as proactive refresh should still be attempted
|
||||
{
|
||||
name: "Authenticated request with token nearing expiry (needs refresh)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// Create token expiring soon (e.g., 30s, within default 60s grace period)
|
||||
exp := time.Now().Add(30 * time.Second).Unix()
|
||||
iat := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
if refreshToken != "valid-refresh-token-for-near-expiry" {
|
||||
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
||||
}
|
||||
// Simulate successful refresh
|
||||
newToken := createNewValidToken()
|
||||
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect success after proactive refresh
|
||||
expectedBody: "OK",
|
||||
},
|
||||
// This test case remains valid as no refresh should be attempted
|
||||
{
|
||||
name: "Authenticated request with token valid (outside grace period)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// Create token expiring later (e.g., 10 mins, outside default 60s grace period)
|
||||
exp := time.Now().Add(10 * time.Minute).Unix()
|
||||
iat := time.Now().Add(-1 * time.Minute).Unix()
|
||||
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
||||
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
// This should NOT be called
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period")
|
||||
return nil, fmt.Errorf("refresh should not have been attempted")
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect success, no refresh needed
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Disallowed Domain (Accept: JSON)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
|
||||
},
|
||||
{
|
||||
name: "Disallowed Domain (Accept: HTML)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
session.SetAccessToken(freshToken)
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
},
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "text/html",
|
||||
},
|
||||
expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response
|
||||
expectedBody: "", // Body check is harder for HTML, focus on status and content-type
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tc.requestPath, nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
req.Header.Set("X-Forwarded-Host", "localhost")
|
||||
// Set common headers needed by the logic (determineScheme, determineHost)
|
||||
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
|
||||
req.Header.Set("X-Forwarded-Host", "testhost.com")
|
||||
req.Host = "testhost.com" // Also set Host header
|
||||
// Set request headers from test case
|
||||
if tc.requestHeaders != nil {
|
||||
for key, value := range tc.requestHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Setup session if needed
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err)
|
||||
}
|
||||
if tc.setupSession != nil {
|
||||
tc.setupSession(session)
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
// Save session to recorder to get cookies
|
||||
saveRecorder := httptest.NewRecorder()
|
||||
if err := session.Save(req, saveRecorder); err != nil {
|
||||
t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err)
|
||||
}
|
||||
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
// Copy cookies from save recorder to the actual request
|
||||
for _, cookie := range saveRecorder.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
rr = httptest.NewRecorder()
|
||||
}
|
||||
|
||||
// Mocking setup for TokenExchanger
|
||||
originalExchanger := ts.tOidc.tokenExchanger // Store original
|
||||
mockExchanger, isMock := originalExchanger.(*MockTokenExchanger)
|
||||
if !isMock {
|
||||
// This case should ideally not happen if Setup correctly assigns the mock,
|
||||
// but handle it defensively.
|
||||
t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.")
|
||||
mockExchanger = &MockTokenExchanger{
|
||||
ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken,
|
||||
RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken,
|
||||
RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider,
|
||||
}
|
||||
ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock
|
||||
}
|
||||
|
||||
// Override specific mock methods if needed for the test case
|
||||
originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func
|
||||
if tc.mockRefreshTokenFunc != nil {
|
||||
// Assign the test case specific mock function
|
||||
mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken)
|
||||
}
|
||||
|
||||
// Call ServeHTTP
|
||||
ts.tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check response
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
// Restore original exchanger and mock function state
|
||||
ts.tOidc.tokenExchanger = originalExchanger
|
||||
if tc.mockRefreshTokenFunc != nil && mockExchanger != nil {
|
||||
// Restore the previous mock function if we overrode it
|
||||
mockExchanger.RefreshTokenFunc = originalMockRefreshFunc
|
||||
}
|
||||
|
||||
// Check response status
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
// Check response body if expected
|
||||
// Check response body if expected (handle JSON vs HTML)
|
||||
if tc.expectedBody != "" {
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Expected body %q, got %q", tc.expectedBody, body)
|
||||
// For JSON, compare directly
|
||||
if strings.Contains(rr.Header().Get("Content-Type"), "application/json") {
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body)
|
||||
}
|
||||
} else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
|
||||
}
|
||||
}
|
||||
// Add more sophisticated HTML body checks if needed
|
||||
}
|
||||
|
||||
// Perform post-request session assertions if defined
|
||||
if tc.assertSessionAfterRequest != nil {
|
||||
tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -552,17 +949,39 @@ func TestHandleCallback(t *testing.T) {
|
||||
name: "Disallowed Email",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
||||
// Generate a unique token for this test case to avoid replay issues
|
||||
// Use claims relevant to this test (disallowed email)
|
||||
now := time.Now()
|
||||
exp := now.Add(1 * time.Hour).Unix()
|
||||
iat := now.Unix()
|
||||
nbf := now.Unix()
|
||||
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": nbf,
|
||||
"sub": "test-subject-disallowed",
|
||||
"email": "user@disallowed.com", // The disallowed email for this test
|
||||
"nonce": "test-nonce", // Match the nonce set in sessionSetupFunc
|
||||
"jti": generateRandomString(16), // Unique JTI
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create disallowed token for test: %w", err)
|
||||
}
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
}, nil
|
||||
},
|
||||
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@disallowed.com",
|
||||
"nonce": "test-nonce",
|
||||
IDToken: disallowedToken,
|
||||
RefreshToken: "test-refresh-token-disallowed",
|
||||
}, nil
|
||||
},
|
||||
// Remove mock extractClaimsFunc - let the real one parse the disallowedToken
|
||||
// The test should still fail correctly on the email check later.
|
||||
// extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
||||
// return map[string]interface{}{
|
||||
// "email": "user@disallowed.com",
|
||||
// "nonce": "test-nonce",
|
||||
// }, nil
|
||||
// },
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
@@ -635,20 +1054,61 @@ func TestHandleCallback(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Clear the global replay cache before each test run
|
||||
replayCacheMu.Lock()
|
||||
replayCache = make(map[string]time.Time) // Reset the global cache
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Explicitly clear the shared blacklist at the start of each sub-test
|
||||
// to ensure no state leaks, even though we expect the local one to be used.
|
||||
// Note: This line might be redundant now that the verifier is local, but keep for safety.
|
||||
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
||||
|
||||
logger := NewLogger("info")
|
||||
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||
|
||||
// Create a new instance for each test to avoid state carryover
|
||||
tOidc := &TraefikOidc{
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: logger,
|
||||
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
|
||||
extractClaimsFunc: tc.extractClaimsFunc,
|
||||
tokenVerifier: ts.tOidc.tokenVerifier,
|
||||
jwtVerifier: ts.tOidc.jwtVerifier,
|
||||
sessionManager: sessionManager,
|
||||
instanceExtractClaimsFunc := tc.extractClaimsFunc
|
||||
if instanceExtractClaimsFunc == nil {
|
||||
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
|
||||
}
|
||||
tOidc := &TraefikOidc{
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: logger,
|
||||
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
|
||||
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
|
||||
tokenVerifier: nil, // Will be set to self below
|
||||
jwtVerifier: nil, // Temporarily nil, will be set below
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run
|
||||
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
// Wrap the test case function to match the required signature
|
||||
if tc.exchangeCodeForToken != nil {
|
||||
// Only call if the test case provided a function
|
||||
return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier)
|
||||
}
|
||||
// Provide a default behavior or error if no mock was provided for this test case
|
||||
return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case")
|
||||
},
|
||||
// Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback
|
||||
},
|
||||
tokenCache: NewTokenCache(), // Initialize token cache
|
||||
limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter
|
||||
tokenBlacklist: NewCache(), // Initialize token blacklist cache
|
||||
|
||||
// Add potentially missing fields based on New() comparison
|
||||
clientID: ts.tOidc.clientID,
|
||||
issuerURL: ts.tOidc.issuerURL,
|
||||
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
|
||||
httpClient: ts.tOidc.httpClient,
|
||||
initComplete: make(chan struct{}), // Initialize the channel
|
||||
// Setting other fields like paths, enablePKCE etc. if needed
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW
|
||||
tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW
|
||||
close(tOidc.initComplete) // Mark this test instance as initialized
|
||||
|
||||
// Create request and response recorder
|
||||
req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil)
|
||||
@@ -839,13 +1299,14 @@ func TestOIDCHandler(t *testing.T) {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Reset token blacklist and cache
|
||||
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
|
||||
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
||||
ts.tOidc.tokenCache = NewTokenCache()
|
||||
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
||||
|
||||
// Set up the test case
|
||||
if tc.blacklist {
|
||||
ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour))
|
||||
// Use Set with a duration. Value 'true' is arbitrary.
|
||||
ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour)
|
||||
}
|
||||
|
||||
if tc.rateLimit {
|
||||
@@ -948,7 +1409,7 @@ func TestHandleLogout(t *testing.T) {
|
||||
endSessionURL: tc.endSessionURL,
|
||||
scheme: "http",
|
||||
logger: logger,
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
||||
httpClient: &http.Client{},
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
@@ -1018,13 +1479,13 @@ func TestHandleLogout(t *testing.T) {
|
||||
|
||||
// Check token blacklist
|
||||
if token := session.GetAccessToken(); token != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
|
||||
t.Error("Access token was not blacklisted")
|
||||
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
||||
t.Error("Access token was not blacklisted in cache")
|
||||
}
|
||||
}
|
||||
if token := session.GetRefreshToken(); token != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
|
||||
t.Error("Refresh token was not blacklisted")
|
||||
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
||||
t.Error("Refresh token was not blacklisted in cache")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1121,8 +1582,9 @@ func TestRevokeToken(t *testing.T) {
|
||||
t.Run("Token revocation", func(t *testing.T) {
|
||||
// Create a new instance for this specific test
|
||||
tOidc := &TraefikOidc{
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
||||
tokenCache: NewTokenCache(),
|
||||
logger: NewLogger("info"), // Initialize the logger
|
||||
}
|
||||
|
||||
// Cache the token
|
||||
@@ -1136,8 +1598,8 @@ func TestRevokeToken(t *testing.T) {
|
||||
t.Error("Token was not removed from cache")
|
||||
}
|
||||
|
||||
// Verify token was added to blacklist
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
|
||||
// Verify token was added to blacklist cache
|
||||
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
||||
t.Error("Token was not added to blacklist")
|
||||
}
|
||||
})
|
||||
@@ -1404,7 +1866,6 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}), config, "test")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
|
||||
}
|
||||
@@ -2043,7 +2504,6 @@ func TestExchangeCodeForToken(t *testing.T) {
|
||||
|
||||
// Test exchangeCodeForToken
|
||||
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -2082,3 +2542,130 @@ func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
||||
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances.
|
||||
func TestVerifyTimeConstraint(t *testing.T) {
|
||||
// Define tolerances used in jwt.go (ensure they match)
|
||||
toleranceFuture := 2 * time.Minute
|
||||
tolerancePast := 10 * time.Second
|
||||
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claimTime time.Time
|
||||
claimName string
|
||||
futureCheck bool // true for exp, false for iat/nbf
|
||||
expectError bool
|
||||
}{
|
||||
// Expiration (future=true, tolerance=2min)
|
||||
{
|
||||
name: "EXP: Valid (expires in 1 min)",
|
||||
claimTime: now.Add(1 * time.Minute),
|
||||
claimName: "exp",
|
||||
futureCheck: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "EXP: Expired (expired 3 min ago)",
|
||||
claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance
|
||||
claimName: "exp",
|
||||
futureCheck: true,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "EXP: Valid (expired 1 min ago, within 2min tolerance)",
|
||||
claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance
|
||||
claimName: "exp",
|
||||
futureCheck: true,
|
||||
expectError: false, // Should be allowed due to future tolerance
|
||||
},
|
||||
|
||||
// Issued At (future=false, tolerance=10s)
|
||||
{
|
||||
name: "IAT: Valid (issued 1 min ago)",
|
||||
claimTime: now.Add(-1 * time.Minute),
|
||||
claimName: "iat",
|
||||
futureCheck: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "IAT: Invalid (issued 15 sec in future)",
|
||||
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
||||
claimName: "iat",
|
||||
futureCheck: false,
|
||||
expectError: true, // "token used before issued"
|
||||
},
|
||||
{
|
||||
name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)",
|
||||
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
||||
claimName: "iat",
|
||||
futureCheck: false,
|
||||
expectError: false, // Should be allowed due to past tolerance
|
||||
},
|
||||
|
||||
// Not Before (future=false, tolerance=10s)
|
||||
{
|
||||
name: "NBF: Valid (active 1 min ago)",
|
||||
claimTime: now.Add(-1 * time.Minute),
|
||||
claimName: "nbf",
|
||||
futureCheck: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "NBF: Invalid (active in 15 sec)",
|
||||
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
||||
claimName: "nbf",
|
||||
futureCheck: false,
|
||||
expectError: true, // "token not yet valid"
|
||||
},
|
||||
{
|
||||
name: "NBF: Valid (active in 5 sec, within 10s tolerance)",
|
||||
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
||||
claimName: "nbf",
|
||||
futureCheck: false,
|
||||
expectError: false, // Should be allowed due to past tolerance
|
||||
},
|
||||
}
|
||||
|
||||
// Temporarily adjust global tolerances for test consistency, then restore
|
||||
originalFutureTolerance := ClockSkewToleranceFuture
|
||||
originalPastTolerance := ClockSkewTolerancePast
|
||||
ClockSkewToleranceFuture = toleranceFuture
|
||||
ClockSkewTolerancePast = tolerancePast
|
||||
defer func() {
|
||||
ClockSkewToleranceFuture = originalFutureTolerance
|
||||
ClockSkewTolerancePast = originalPastTolerance
|
||||
}()
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Convert claim time to float64 unix timestamp
|
||||
unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9
|
||||
|
||||
var err error
|
||||
// Call the specific verification function which uses verifyTimeConstraint
|
||||
if tc.claimName == "exp" {
|
||||
err = verifyExpiration(unixTime)
|
||||
} else if tc.claimName == "iat" {
|
||||
err = verifyIssuedAt(unixTime)
|
||||
} else if tc.claimName == "nbf" {
|
||||
err = verifyNotBefore(unixTime)
|
||||
} else {
|
||||
t.Fatalf("Unknown claim name in test setup: %s", tc.claimName)
|
||||
}
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now)
|
||||
} else {
|
||||
t.Logf("Got expected error: %v", err) // Log the error for confirmation
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
} // Add missing closing brace for TestVerifyTimeConstraint
|
||||
|
||||
+32
-11
@@ -15,6 +15,8 @@ type MetadataCache struct {
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// NewMetadataCache creates a new MetadataCache instance.
|
||||
// It initializes the cache structure and starts the background cleanup goroutine.
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
@@ -24,7 +26,8 @@ func NewMetadataCache() *MetadataCache {
|
||||
return c
|
||||
}
|
||||
|
||||
// Cleanup removes expired metadata from the cache.
|
||||
// Cleanup removes the cached provider metadata if it has expired.
|
||||
// This is called periodically by the auto-cleanup goroutine.
|
||||
func (c *MetadataCache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
@@ -34,11 +37,32 @@ func (c *MetadataCache) Cleanup() {
|
||||
c.metadata = nil
|
||||
}
|
||||
}
|
||||
|
||||
// isCacheValid checks if the cached metadata is present and has not expired.
|
||||
// Note: This function assumes the read lock is held or it's called from a context
|
||||
// where the lock is already held (like within GetMetadata after locking).
|
||||
func (c *MetadataCache) isCacheValid() bool {
|
||||
return c.metadata != nil && time.Now().Before(c.expiresAt)
|
||||
}
|
||||
|
||||
// GetMetadata retrieves the metadata from cache or fetches it if expired
|
||||
// GetMetadata retrieves the OIDC provider metadata.
|
||||
// It first checks the cache for valid, non-expired metadata. If found, it's returned immediately.
|
||||
// If the cache is empty or expired, it attempts to fetch the metadata from the provider's
|
||||
// well-known endpoint using discoverProviderMetadata.
|
||||
// If fetching is successful, the new metadata is cached for 1 hour.
|
||||
// If fetching fails but valid metadata exists in the cache (even if expired), the cache expiry
|
||||
// is extended by 5 minutes, and the cached data is returned to prevent thundering herd issues.
|
||||
// If fetching fails and there's no cached data, an error is returned.
|
||||
// It employs double-checked locking for thread safety and performance.
|
||||
//
|
||||
// Parameters:
|
||||
// - providerURL: The base URL of the OIDC provider.
|
||||
// - httpClient: The HTTP client to use for fetching metadata.
|
||||
// - logger: The logger instance for recording errors or warnings.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the ProviderMetadata struct.
|
||||
// - An error if metadata cannot be retrieved from cache or fetched from the provider.
|
||||
func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) {
|
||||
c.mutex.RLock()
|
||||
if c.isCacheValid() {
|
||||
@@ -67,24 +91,21 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
|
||||
}
|
||||
|
||||
c.metadata = metadata
|
||||
// Calculate expiration time based on usage patterns
|
||||
usageCount := 0 // This should be replaced with actual usage tracking logic
|
||||
if usageCount < 10 {
|
||||
c.expiresAt = time.Now().Add(30 * time.Minute)
|
||||
} else if usageCount < 50 {
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
} else {
|
||||
c.expiresAt = time.Now().Add(2 * time.Hour)
|
||||
}
|
||||
// Set a fixed cache lifetime (e.g., 1 hour)
|
||||
// TODO: Consider making this configurable or respecting HTTP cache headers
|
||||
c.expiresAt = time.Now().Add(1 * time.Hour)
|
||||
|
||||
// End of GetMetadata
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// startAutoCleanup starts the background goroutine that periodically calls Cleanup
|
||||
// to remove expired metadata from the cache.
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
autoCleanupRoutine(c.autoCleanupInterval, c.stopCleanup, c.Cleanup)
|
||||
}
|
||||
|
||||
// Close stops the automatic cleanup goroutine associated with this metadata cache.
|
||||
func (c *MetadataCache) Close() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
+191
-42
@@ -16,8 +16,15 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateSecureRandomString creates a cryptographically secure random string of specified length.
|
||||
// It returns the generated string or an error if random generation fails.
|
||||
// generateSecureRandomString creates a cryptographically secure, hex-encoded random string.
|
||||
// It reads the specified number of bytes from crypto/rand and encodes them as a hexadecimal string.
|
||||
//
|
||||
// Parameters:
|
||||
// - length: The number of random bytes to generate (the resulting hex string will be twice this length).
|
||||
//
|
||||
// Returns:
|
||||
// - A hex-encoded random string.
|
||||
// - An error if reading random bytes fails.
|
||||
func generateSecureRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -56,7 +63,14 @@ const (
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// compressToken compresses a token using gzip and base64 encodes it.
|
||||
// compressToken compresses the input string using gzip and then encodes the result using standard base64 encoding.
|
||||
// If any error occurs during compression, it returns the original uncompressed token as a fallback.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The string to compress.
|
||||
//
|
||||
// Returns:
|
||||
// - The base64 encoded, gzipped string, or the original string if compression fails.
|
||||
func compressToken(token string) string {
|
||||
var b bytes.Buffer
|
||||
gz := gzip.NewWriter(&b)
|
||||
@@ -69,7 +83,15 @@ func compressToken(token string) string {
|
||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||
}
|
||||
|
||||
// decompressToken decompresses a base64 encoded gzipped token.
|
||||
// decompressToken decodes a standard base64 encoded string and then decompresses the result using gzip.
|
||||
// If base64 decoding or gzip decompression fails, it returns the original input string as a fallback,
|
||||
// assuming it might not have been compressed.
|
||||
//
|
||||
// Parameters:
|
||||
// - compressed: The base64 encoded, gzipped string.
|
||||
//
|
||||
// Returns:
|
||||
// - The decompressed original string, or the input string if decompression fails.
|
||||
func decompressToken(compressed string) string {
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
@@ -128,25 +150,27 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
|
||||
// Initialize session pool.
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
// Initialize SessionData with necessary fields and the mutex.
|
||||
return &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
}
|
||||
}
|
||||
|
||||
return sm, nil
|
||||
}
|
||||
|
||||
// getSessionOptions returns secure session options configured for the current request.
|
||||
// Parameters:
|
||||
// - isSecure: Whether the current request is using HTTPS.
|
||||
// getSessionOptions returns a sessions.Options struct configured with security best practices.
|
||||
// It sets HttpOnly to true, Secure based on the request scheme or forceHTTPS setting,
|
||||
// SameSite to LaxMode, MaxAge to the absoluteSessionTimeout, and Path to "/".
|
||||
//
|
||||
// The options ensure cookies are:
|
||||
// - HTTP-only (not accessible via JavaScript)
|
||||
// - Secure when using HTTPS or when forceHTTPS is enabled
|
||||
// - Using SameSite=Lax for CSRF protection
|
||||
// - Set with appropriate timeout and path settings
|
||||
// Parameters:
|
||||
// - isSecure: A boolean indicating if the current request context is secure (HTTPS).
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a configured sessions.Options struct.
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
@@ -208,11 +232,14 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves all session chunks for a given token type.
|
||||
// getTokenChunkSessions retrieves all cookie chunks associated with a large token (access or refresh).
|
||||
// It iteratively attempts to load cookies named "{baseName}_0", "{baseName}_1", etc., until
|
||||
// a cookie is not found or returns an error. The loaded sessions are stored in the provided chunks map.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request
|
||||
// - baseName: The base name for the token's session cookies
|
||||
// - chunks: Map to store the chunks in
|
||||
// - r: The incoming HTTP request containing the cookies.
|
||||
// - baseName: The base name of the cookie (e.g., accessTokenCookie).
|
||||
// - chunks: The map (typically SessionData.accessTokenChunks or SessionData.refreshTokenChunks) to populate with the found session chunks.
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
@@ -251,12 +278,21 @@ type SessionData struct {
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshMutex protects refresh token operations within this session instance.
|
||||
refreshMutex sync.Mutex
|
||||
}
|
||||
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
// It saves the main session, token sessions, and any token chunks,
|
||||
// applying appropriate security options to each cookie. All cookies
|
||||
// are saved with consistent security settings based on the request scheme.
|
||||
// Save persists all parts of the session (main, access token, refresh token, and any chunks)
|
||||
// back to the client as cookies in the HTTP response. It applies secure cookie options
|
||||
// obtained via getSessionOptions based on the request's security context.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The original HTTP request (used to determine security context for cookie options).
|
||||
// - w: The HTTP response writer to which the Set-Cookie headers will be added.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if saving any of the session components fails.
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
@@ -300,7 +336,19 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||
// Clear removes all session data associated with this SessionData instance.
|
||||
// It clears the values map of the main, access, and refresh sessions, sets their MaxAge to -1
|
||||
// to expire the cookies immediately, and clears any associated token chunk cookies.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired sessions to send the
|
||||
// expiring Set-Cookie headers. Finally, it clears internal fields and returns the SessionData
|
||||
// object to the pool.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request (required by the underlying session store).
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
//
|
||||
// Returns:
|
||||
// - An error if saving the expired sessions fails (only if w is not nil).
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions.
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
@@ -335,7 +383,12 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// clearTokenChunks removes all session chunks for a given token type.
|
||||
// clearTokenChunks iterates through a map of session chunks, clears their values,
|
||||
// and sets their MaxAge to -1 to expire them. This is used internally by Clear.
|
||||
//
|
||||
// Parameters:
|
||||
// - r: The HTTP request (required by the underlying session store, though not directly used here).
|
||||
// - chunks: The map of session chunks (e.g., sd.accessTokenChunks) to clear and expire.
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
@@ -345,7 +398,12 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns whether the current session is authenticated.
|
||||
// GetAuthenticated checks if the session is marked as authenticated and has not exceeded
|
||||
// the absolute session timeout.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the "authenticated" flag is set to true and the session creation time is within the allowed timeout.
|
||||
// - false otherwise.
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
if !auth {
|
||||
@@ -360,8 +418,15 @@ func (sd *SessionData) GetAuthenticated() bool {
|
||||
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
|
||||
}
|
||||
|
||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||
// Returns an error if generating a new session ID fails.
|
||||
// SetAuthenticated sets the authentication status of the session.
|
||||
// If setting to true, it generates a new secure session ID for the main session
|
||||
// to prevent session fixation attacks and records the current time as the creation time.
|
||||
//
|
||||
// Parameters:
|
||||
// - value: The boolean authentication status (true for authenticated, false otherwise).
|
||||
//
|
||||
// Returns:
|
||||
// - An error if generating a new session ID fails when setting value to true.
|
||||
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
if value {
|
||||
id, err := generateSecureRandomString(32)
|
||||
@@ -375,7 +440,12 @@ func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccessToken retrieves the complete access token from the session.
|
||||
// GetAccessToken retrieves the access token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed access token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -409,7 +479,14 @@ func (sd *SessionData) GetAccessToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// SetAccessToken stores the access token in the session.
|
||||
// SetAccessToken stores the provided access token in the session.
|
||||
// It first expires any existing access token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary access token session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_a_0, _oidc_raczylo_a_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The access token string to store.
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
@@ -439,7 +516,12 @@ func (sd *SessionData) SetAccessToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||
// GetRefreshToken retrieves the refresh token stored in the session.
|
||||
// It handles reassembling the token from multiple cookie chunks if necessary
|
||||
// and decompresses it if it was stored compressed.
|
||||
//
|
||||
// Returns:
|
||||
// - The complete, decompressed refresh token string, or an empty string if not found.
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
@@ -473,7 +555,14 @@ func (sd *SessionData) GetRefreshToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// SetRefreshToken stores the refresh token in the session.
|
||||
// SetRefreshToken stores the provided refresh token in the session.
|
||||
// It first expires any existing refresh token chunk cookies.
|
||||
// It then compresses the token. If the compressed token fits within a single cookie (maxCookieSize),
|
||||
// it's stored directly in the primary refresh token session. Otherwise, the compressed token
|
||||
// is split into chunks, and each chunk is stored in a separate numbered cookie (_oidc_raczylo_r_0, _oidc_raczylo_r_1, etc.).
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The refresh token string to store.
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Expire any existing chunk cookies first.
|
||||
if sd.request != nil {
|
||||
@@ -503,7 +592,13 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireAccessTokenChunks expires any existing access token chunk cookies.
|
||||
// expireAccessTokenChunks finds all existing access token chunk cookies (_oidc_raczylo_a_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new access token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
@@ -521,7 +616,13 @@ func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
|
||||
// expireRefreshTokenChunks finds all existing refresh token chunk cookies (_oidc_raczylo_r_N)
|
||||
// associated with the current request, clears their values, and sets their MaxAge to -1.
|
||||
// If a ResponseWriter is provided, it attempts to save the expired chunk sessions to send
|
||||
// the expiring Set-Cookie headers. This is used internally when setting a new refresh token.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The HTTP response writer (optional). If provided, expiring Set-Cookie headers will be sent.
|
||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
@@ -539,7 +640,15 @@ func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size.
|
||||
// splitIntoChunks divides a string `s` into a slice of strings, where each element
|
||||
// has a maximum length of `chunkSize`.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to split.
|
||||
// - chunkSize: The maximum size of each chunk.
|
||||
//
|
||||
// Returns:
|
||||
// - A slice of strings representing the chunks.
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
@@ -554,57 +663,97 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF retrieves the CSRF token from the session.
|
||||
// GetCSRF retrieves the Cross-Site Request Forgery (CSRF) token stored in the main session.
|
||||
//
|
||||
// Returns:
|
||||
// - The CSRF token string, or an empty string if not set.
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF stores a new CSRF token in the session.
|
||||
// SetCSRF stores the provided CSRF token string in the main session.
|
||||
// This token is typically generated at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The CSRF token to store.
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce retrieves the nonce value from the session.
|
||||
// GetNonce retrieves the OIDC nonce value stored in the main session.
|
||||
// The nonce is used to associate an ID token with the specific authentication request.
|
||||
//
|
||||
// Returns:
|
||||
// - The nonce string, or an empty string if not set.
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce stores a new nonce value in the session.
|
||||
// SetNonce stores the provided OIDC nonce string in the main session.
|
||||
// This nonce is typically generated at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - nonce: The nonce string to store.
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetCodeVerifier retrieves the PKCE code verifier from the session.
|
||||
// GetCodeVerifier retrieves the PKCE (Proof Key for Code Exchange) code verifier
|
||||
// stored in the main session. This is only relevant if PKCE is enabled.
|
||||
//
|
||||
// Returns:
|
||||
// - The code verifier string, or an empty string if not set or PKCE is disabled.
|
||||
func (sd *SessionData) GetCodeVerifier() string {
|
||||
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||
return codeVerifier
|
||||
}
|
||||
|
||||
// SetCodeVerifier stores the PKCE code verifier in the session.
|
||||
// SetCodeVerifier stores the provided PKCE code verifier string in the main session.
|
||||
// This is typically called at the start of the authentication flow if PKCE is enabled.
|
||||
//
|
||||
// Parameters:
|
||||
// - codeVerifier: The PKCE code verifier string to store.
|
||||
func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
sd.mainSession.Values["code_verifier"] = codeVerifier
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address from the session.
|
||||
// GetEmail retrieves the authenticated user's email address stored in the main session.
|
||||
// This is typically extracted from the ID token claims after successful authentication.
|
||||
//
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail stores the user's email address in the session.
|
||||
// SetEmail stores the provided user email address string in the main session.
|
||||
// This is typically called after successful authentication and claim extraction.
|
||||
//
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
|
||||
// GetIncomingPath retrieves the original request URI (including query parameters)
|
||||
// that the user was trying to access before being redirected for authentication.
|
||||
// This is stored in the main session to allow redirection back after successful login.
|
||||
//
|
||||
// Returns:
|
||||
// - The original request URI string, or an empty string if not set.
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath stores the original request path that triggered the authentication flow.
|
||||
// SetIncomingPath stores the original request URI (path and query parameters)
|
||||
// in the main session. This is typically called at the start of the authentication flow.
|
||||
//
|
||||
// Parameters:
|
||||
// - path: The original request URI string (e.g., "/protected/resource?id=123").
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
|
||||
+9
-2
@@ -1,7 +1,9 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -12,7 +14,12 @@ func generateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[rand.Intn(len(charset))]
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
// Handle error appropriately in a real application, maybe panic in test helper
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
+107
-38
@@ -84,6 +84,11 @@ type Config struct {
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
|
||||
// RefreshGracePeriodSeconds defines how many seconds before a token expires
|
||||
// the plugin should attempt to refresh it proactively (optional)
|
||||
// Default: 60
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -109,21 +114,35 @@ const (
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
//
|
||||
// CreateConfig initializes a new Config struct with default values for optional fields.
|
||||
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
|
||||
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
|
||||
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a new Config struct with default settings applied.
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate performs validation checks on the Config.
|
||||
// It ensures all required fields are set and have valid values.
|
||||
// Returns an error if any validation check fails.
|
||||
// Validate checks the configuration settings for validity.
|
||||
// It ensures that required fields (ProviderURL, CallbackURL, ClientID, ClientSecret, SessionEncryptionKey)
|
||||
// are present and that URLs are well-formed (HTTPS where required). It also validates
|
||||
// the session key length, log level, rate limit, and refresh grace period.
|
||||
//
|
||||
// Returns:
|
||||
// - nil if the configuration is valid.
|
||||
// - An error describing the first validation failure encountered.
|
||||
func (c *Config) Validate() error {
|
||||
// Validate provider URL
|
||||
if c.ProviderURL == "" {
|
||||
@@ -197,16 +216,34 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
|
||||
}
|
||||
|
||||
// Validate refresh grace period
|
||||
if c.RefreshGracePeriodSeconds < 0 {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidSecureURL checks if the provided string is a valid HTTPS URL
|
||||
// isValidSecureURL checks if a given string represents a valid, absolute HTTPS URL.
|
||||
// It uses url.Parse and checks for a nil error, an "https" scheme, and a non-empty host.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The URL string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the string is a valid HTTPS URL, false otherwise.
|
||||
func isValidSecureURL(s string) bool {
|
||||
u, err := url.Parse(s)
|
||||
return err == nil && u.Scheme == "https" && u.Host != ""
|
||||
}
|
||||
|
||||
// isValidLogLevel checks if the provided log level is valid
|
||||
// isValidLogLevel checks if the provided log level string is one of the supported values ("debug", "info", "error").
|
||||
//
|
||||
// Parameters:
|
||||
// - level: The log level string to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the log level is valid, false otherwise.
|
||||
func isValidLogLevel(level string) bool {
|
||||
return level == "debug" || level == "info" || level == "error"
|
||||
}
|
||||
@@ -223,14 +260,20 @@ type Logger struct {
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger with the specified log level.
|
||||
// The log level determines which messages are output:
|
||||
// - "debug": Outputs all messages (debug, info, error)
|
||||
// - "info": Outputs info and error messages
|
||||
// - "error": Outputs only error messages
|
||||
// NewLogger creates and configures a new Logger instance based on the provided log level.
|
||||
// It initializes loggers for ERROR (stderr), INFO (stdout), and DEBUG (stdout) levels,
|
||||
// enabling output based on the specified level:
|
||||
// - "error": Only ERROR messages are output.
|
||||
// - "info": INFO and ERROR messages are output.
|
||||
// - "debug": DEBUG, INFO, and ERROR messages are output.
|
||||
//
|
||||
// Error messages are always written to stderr, while info and debug
|
||||
// messages are written to stdout when enabled.
|
||||
// If an invalid level is provided, it defaults to behavior similar to "error".
|
||||
//
|
||||
// Parameters:
|
||||
// - logLevel: The desired logging level ("debug", "info", or "error").
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to the configured Logger instance.
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
@@ -252,51 +295,77 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an informational message.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
// Info logs a message at the INFO level using Printf style formatting.
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
// Debug logs a message at the DEBUG level using Printf style formatting.
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
// Error logs a message at the ERROR level using Printf style formatting.
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an informational message using Printf formatting.
|
||||
// These messages are intended for general operational information
|
||||
// and are written to stdout.
|
||||
// Infof logs a message at the INFO level using Printf style formatting.
|
||||
// Equivalent to calling l.Info(format, args...).
|
||||
// Output is directed to stdout if the configured log level is "info" or "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message using Printf formatting.
|
||||
// These messages are only output when debug level logging is enabled
|
||||
// and are intended for detailed troubleshooting information.
|
||||
// Debugf logs a message at the DEBUG level using Printf style formatting.
|
||||
// Equivalent to calling l.Debug(format, args...).
|
||||
// Output is directed to stdout only if the configured log level is "debug".
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message using Printf formatting.
|
||||
// These messages indicate problems that need attention and are
|
||||
// always written to stderr regardless of the log level.
|
||||
// Errorf logs a message at the ERROR level using Printf style formatting.
|
||||
// Equivalent to calling l.Error(format, args...).
|
||||
// Output is always directed to stderr, regardless of the configured log level.
|
||||
//
|
||||
// Parameters:
|
||||
// - format: The format string (as in fmt.Printf).
|
||||
// - args: The arguments for the format string.
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to both the HTTP response and the error log.
|
||||
// It ensures consistent error handling across the middleware by logging the error
|
||||
// and sending an appropriate HTTP response to the client.
|
||||
// handleError logs an error message using the provided logger and sends an HTTP error
|
||||
// response to the client with the specified message and status code.
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The http.ResponseWriter to send the error response to.
|
||||
// - message: The error message string.
|
||||
// - code: The HTTP status code for the response.
|
||||
// - logger: The Logger instance to use for logging the error.
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
Reference in New Issue
Block a user